Data Processing

We rely on TorchGeo for the implementation of datasets and data modules.

Check out the TorchGeo tutorials on datasets for more in-depth information.

In general, we recommend creating a TorchGeo dataset specifically for your data. This gives you full control over how data is loaded, which transforms are applied to it, and even how samples are plotted if you log with tools like TensorBoard.

TorchGeo provides GeoDataset and NonGeoDataset.

  • If your data is already tiled and ready to be consumed by a neural network, you can inherit from NonGeoDataset. This is essentially a wrapper around a regular PyTorch dataset.
  • If your data consists of large GeoTIFFs that you would like to sample from during training, you can leverage TorchGeo's powerful GeoDataset. It automatically aligns your input data and labels and enables a variety of geo-aware samplers.
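As a rough sketch of the NonGeoDataset contract, the hypothetical class below returns one dict sample per index with "image" and "mask" keys and applies an optional transform to the whole dict. It uses plain NumPy so it runs without TorchGeo installed; in practice you would inherit from torchgeo.datasets.NonGeoDataset and return tensors.

```python
from typing import Callable, Optional

import numpy as np

Sample = dict[str, np.ndarray]


class InMemoryTileDataset:
    """Hypothetical pre-tiled dataset that yields TorchGeo-style dict samples."""

    def __init__(self, images: np.ndarray, masks: np.ndarray,
                 transforms: Optional[Callable[[Sample], Sample]] = None) -> None:
        if len(images) != len(masks):
            raise ValueError("images and masks must contain the same number of tiles")
        self.images = images  # (N, C, H, W)
        self.masks = masks  # (N, H, W)
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int) -> Sample:
        sample = {"image": self.images[index], "mask": self.masks[index]}
        if self.transforms is not None:
            # The transform sees (and returns) the whole dict, so image and mask stay in sync.
            sample = self.transforms(sample)
        return sample


def flip_width(sample: Sample) -> Sample:
    # Flip the last (width) axis of every array in the sample.
    return {key: value[..., ::-1] for key, value in sample.items()}


images = np.random.default_rng(0).random((4, 3, 8, 8), dtype=np.float32)
masks = np.zeros((4, 8, 8), dtype=np.int64)
dataset = InMemoryTileDataset(images, masks, transforms=flip_width)
first = dataset[0]
```

Passing the whole dict to a single transform is what lets geometric augmentations move the image and the mask together.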

For additional examples of fine-tuning a TerraTorch model with these components, please refer to the Prithvi EO Examples repository.

Using Datasets already implemented in TorchGeo

Using existing TorchGeo DataModules is very easy: just plug them in! For instance, to use the EuroSATDataModule, set the data section of your config file as follows:

data:
  class_path: torchgeo.datamodules.EuroSATDataModule
  init_args:
    batch_size: 32
    num_workers: 8
  dict_kwargs:
    root: /dccstor/geofm-pre/EuroSat
    download: True
    bands:
      - B02
      - B03
      - B04
      - B08A
      - B09
      - B10
Modify each parameter as you see fit.

You can also do this outside of config files: simply instantiate the data module as usual and pass it to your Lightning Trainer.

Warning

To pass transforms to TorchGeo DataModules from config files, you must use the following format:

data:
  class_path: terratorch.datamodules.TorchNonGeoDataModule
  init_args:
    cls: torchgeo.datamodules.EuroSATDataModule
    transforms:
      - class_path: albumentations.augmentations.geometric.resize.Resize
        init_args:
          height: 224
          width: 224
      - class_path: ToTensorV2
Note that the class_path is TorchNonGeoDataModule and the TorchGeo class to be used is passed through cls (there is also a TorchGeoDataModule for geo data modules). This is necessary because TorchGeo receives the transforms argument through **kwargs, which makes it difficult to instantiate with LightningCLI. See more details below.

terratorch.datamodules.torchgeo_data_module

Proxy objects that make config-file parsing work with transforms.

These are necessary because LightningCLI can only instantiate config arguments as objects if those arguments have type annotations.

In TorchGeo, transforms is passed through **kwargs, so it has no type annotation. To get around this, we create wrappers that annotate transforms with a type. They build the transforms and forward method and attribute calls to the original TorchGeo data module.

Additionally, TorchGeo datasets pass the data to the transforms callable as a single dict of tensors.

Albumentations, on the other hand, expects the data as separate keyword arguments and as NumPy arrays. These wrappers handle that conversion.

TorchGeoDataModule

Bases: GeoDataModule

Proxy object for using Geo data modules defined by TorchGeo.

Allows for transforms to be defined and passed using config files. The only reason this class exists is so that we can annotate the transforms argument with a type, which LightningCLI and config files require. Method calls and attribute access are forwarded to the underlying TorchGeo data module.
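A stripped-down version of the proxy idea, using hypothetical names (ToyDataModule, DataModuleProxy, compose) and plain callables in place of Albumentations transforms: the wrapper exposes a typed transforms argument, folds it into the kwargs passed to the wrapped class, and forwards attribute reads. The TerraTorch classes in this module forward a fixed set of methods and properties explicitly rather than through __getattr__.

```python
from typing import Any, Callable, Optional

Transform = Callable[[Any], Any]


def compose(transforms: list[Transform]) -> Transform:
    """Chain callables left to right."""

    def apply(x: Any) -> Any:
        for transform in transforms:
            x = transform(x)
        return x

    return apply


class ToyDataModule:
    """Stand-in for a TorchGeo data module that only receives transforms via **kwargs."""

    def __init__(self, batch_size: int = 1, num_workers: int = 0, **kwargs: Any) -> None:
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transforms = kwargs.get("transforms")

    def train_dataloader(self) -> str:
        return f"loader(batch_size={self.batch_size}, num_workers={self.num_workers})"


class DataModuleProxy:
    def __init__(self, cls: type, batch_size: Optional[int] = None, num_workers: int = 0,
                 transforms: Optional[list[Transform]] = None, **kwargs: Any) -> None:
        if batch_size is not None:
            kwargs["batch_size"] = batch_size  # only override the wrapped default if given
        if transforms is not None:
            kwargs["transforms"] = compose(transforms)
        self._proxy = cls(num_workers=num_workers, **kwargs)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails; guard against recursion before _proxy exists.
        if name == "_proxy":
            raise AttributeError(name)
        return getattr(self._proxy, name)


dm = DataModuleProxy(ToyDataModule, batch_size=8, transforms=[lambda x: x + 1, lambda x: x * 2])
```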

Source code in terratorch/datamodules/torchgeo_data_module.py
class TorchGeoDataModule(GeoDataModule):
    """Proxy object for using Geo data modules defined by TorchGeo.

    Allows for transforms to be defined and passed using config files.
    The only reason this class exists is so that we can annotate the transforms argument with a type.
    This is required for lightningcli and config files.
    As such, all getattr and setattr will be redirected to the underlying class.
    """

    def __init__(
        self,
        cls: type[GeoDataModule],
        batch_size: int | None = None,
        num_workers: int = 0,
        transforms: None | list[BasicTransform] = None,
        **kwargs: Any,
    ):
        """Constructor

        Args:
            cls (type[GeoDataModule]): TorchGeo DataModule class to be instantiated
            batch_size (int | None, optional): batch_size. Defaults to None.
            num_workers (int, optional): num_workers. Defaults to 0.
            transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
                Should end with ToTensorV2. Defaults to None.
            **kwargs (Any): Arguments passed to instantiate `cls`.
        """
        if batch_size is not None:
            kwargs["batch_size"] = batch_size
        if transforms is not None:
            transforms_as_callable = albumentations_to_callable_with_dict(transforms)
            kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
        # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
        self._proxy = cls(num_workers=num_workers, **kwargs)
        super().__init__(self._proxy.dataset_class)  # dummy arg

    @property
    def collate_fn(self):
        return self._proxy.collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        self._proxy.collate_fn = value

    @property
    def patch_size(self):
        return self._proxy.patch_size

    @property
    def length(self):
        return self._proxy.length

    def setup(self, stage: str):
        return self._proxy.setup(stage)

    def train_dataloader(self):
        return self._proxy.train_dataloader()

    def val_dataloader(self):
        return self._proxy.val_dataloader()

    def test_dataloader(self):
        return self._proxy.test_dataloader()

    def predict_dataloader(self):
        return self._proxy.predict_dataloader()

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        return self._proxy.transfer_batch_to_device(batch, device, dataloader_idx)

__init__(cls, batch_size=None, num_workers=0, transforms=None, **kwargs)

Constructor

Parameters:
  • cls (type[GeoDataModule]) –

    TorchGeo DataModule class to be instantiated

  • batch_size (int | None, default: None ) –

    batch_size. Defaults to None.

  • num_workers (int, default: 0 ) –

    num_workers. Defaults to 0.

  • transforms (None | list[BasicTransform], default: None ) –

    List of Albumentations Transforms. Should end with ToTensorV2. Defaults to None.

  • **kwargs (Any, default: {} ) –

    Arguments passed to instantiate cls.


TorchNonGeoDataModule

Bases: NonGeoDataModule

Proxy object for using NonGeo data modules defined by TorchGeo.

Allows for transforms to be defined and passed using config files. The only reason this class exists is so that we can annotate the transforms argument with a type, which LightningCLI and config files require. Method calls and attribute access are forwarded to the underlying TorchGeo data module.

Source code in terratorch/datamodules/torchgeo_data_module.py
class TorchNonGeoDataModule(NonGeoDataModule):
    """Proxy object for using NonGeo data modules defined by TorchGeo.

    Allows for transforms to be defined and passed using config files.
    The only reason this class exists is so that we can annotate the transforms argument with a type.
    This is required for lightningcli and config files.
    As such, all getattr and setattr will be redirected to the underlying class.
    """

    def __init__(
        self,
        cls: type[NonGeoDataModule],
        batch_size: int | None = None,
        num_workers: int = 0,
        transforms: None | list[BasicTransform] = None,
        **kwargs: Any,
    ):
        """Constructor

        Args:
            cls (type[NonGeoDataModule]): TorchGeo DataModule class to be instantiated
            batch_size (int | None, optional): batch_size. Defaults to None.
            num_workers (int, optional): num_workers. Defaults to 0.
            transforms (None | list[BasicTransform], optional): List of Albumentations Transforms.
                Should end with ToTensorV2. Defaults to None.
            **kwargs (Any): Arguments passed to instantiate `cls`.
        """
        if batch_size is not None:
            kwargs["batch_size"] = batch_size
        if transforms is not None:
            transforms_as_callable = albumentations_to_callable_with_dict(transforms)
            kwargs["transforms"] = build_callable_transform_from_torch_tensor(transforms_as_callable)
        # self.__dict__["datamodule"] = cls(num_workers=num_workers, **kwargs)
        self._proxy = cls(num_workers=num_workers, **kwargs)
        super().__init__(self._proxy.dataset_class)  # dummy arg

    @property
    def collate_fn(self):
        return self._proxy.collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        self._proxy.collate_fn = value

    def setup(self, stage: str):
        return self._proxy.setup(stage)

    def train_dataloader(self):
        return self._proxy.train_dataloader()

    def val_dataloader(self):
        return self._proxy.val_dataloader()

    def test_dataloader(self):
        return self._proxy.test_dataloader()

    def predict_dataloader(self):
        return self._proxy.predict_dataloader()

__init__(cls, batch_size=None, num_workers=0, transforms=None, **kwargs)

Constructor

Parameters:
  • cls (type[NonGeoDataModule]) –

    TorchGeo DataModule class to be instantiated

  • batch_size (int | None, default: None ) –

    batch_size. Defaults to None.

  • num_workers (int, default: 0 ) –

    num_workers. Defaults to 0.

  • transforms (None | list[BasicTransform], default: None ) –

    List of Albumentations Transforms. Should end with ToTensorV2. Defaults to None.

  • **kwargs (Any, default: {} ) –

    Arguments passed to instantiate cls.


Generic datasets and data modules

For the NonGeoDataset case, we also provide "generic" datasets and datamodules. These can be used when you would like to load data from given directories, in a style similar to the MMLab libraries.

Generic Datasets

terratorch.datasets.generic_pixel_wise_dataset

Module containing generic dataset classes

GenericNonGeoPixelwiseRegressionDataset

Bases: GenericPixelWiseDataset

GenericNonGeoPixelwiseRegressionDataset

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
class GenericNonGeoPixelwiseRegressionDataset(GenericPixelWiseDataset):
    """GenericNonGeoPixelwiseRegressionDataset"""

    def __init__(
        self,
        data_root: Path,
        label_data_root: Path | None = None,
        image_grep: str | None = "*",
        label_grep: str | None = "*",
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            label_data_root (Path, optional): Path to data root directory with labels.
                If not specified, will use the same as for images.
            image_grep (str, optional): Glob pattern appended to data_root to find input images.
                Defaults to "*".
            label_grep (str, optional): Glob pattern appended to data_root to find ground truth masks.
                Defaults to "*".
            split (Path, optional): Path to file containing files to be used for this split.
                The file should contain newline-separated prefixes that appear in the names of the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If None, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value. If None, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
        """
        super().__init__(
            data_root,
            label_data_root=label_data_root,
            image_grep=image_grep,
            label_grep=label_grep,
            split=split,
            ignore_split_file_extensions=ignore_split_file_extensions,
            allow_substring_split_file=allow_substring_split_file,
            rgb_indices=rgb_indices,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=constant_scale,
            transform=transform,
            no_data_replace=no_data_replace,
            no_label_replace=no_label_replace,
            expand_temporal_dimension=expand_temporal_dimension,
            reduce_zero_label=reduce_zero_label,
        )

    def __getitem__(self, index: int) -> dict[str, Any]:
        item = super().__getitem__(index)
        item["mask"] = item["mask"].float()
        return item

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, Tensor]): a sample returned by :meth:`__getitem__`
            suptitle (str|None): optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample

        .. versionadded:: 0.2
        """
        image = sample["image"]
        if len(image.shape) == 5:
            return
        if isinstance(image, Tensor):
            image = image.numpy()
        image = image.take(self.rgb_indices, axis=0)
        image = np.transpose(image, (1, 2, 0))
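        image_min = image.min(axis=(0, 1))
        image_max = image.max(axis=(0, 1))
        # Per-channel min-max scaling; the epsilon avoids dividing by zero on constant bands.
        image = (image - image_min) / (image_max - image_min + 1e-8)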
        image = np.clip(image, 0, 1)

        label_mask = sample["mask"]
        if isinstance(label_mask, Tensor):
            label_mask = label_mask.numpy()

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction_mask = sample["prediction"]
            if isinstance(prediction_mask, Tensor):
                prediction_mask = prediction_mask.numpy()

        return self._plot_sample(
            image,
            label_mask,
            prediction=prediction_mask if showing_predictions else None,
            suptitle=suptitle,
        )

    @staticmethod
    def _plot_sample(image, label, prediction=None, suptitle=None):
        num_images = 4 if prediction is not None else 3
        fig, ax = plt.subplots(1, num_images, figsize=(12, 10), layout="compressed")

        norm = mpl.colors.Normalize(vmin=label.min(), vmax=label.max())
        ax[0].axis("off")
        ax[0].title.set_text("Image")
        ax[0].imshow(image)

        ax[1].axis("off")
        ax[1].title.set_text("Ground Truth Mask")
        ax[1].imshow(label, cmap="Greens", norm=norm)

        ax[2].axis("off")
        ax[2].title.set_text("GT Mask on Image")
        ax[2].imshow(image)
        ax[2].imshow(label, cmap="Greens", alpha=0.3, norm=norm)
        # ax[2].legend()

        if prediction is not None:
            ax[3].title.set_text("Predicted Mask")
            ax[3].imshow(prediction, cmap="Greens", norm=norm)

        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
__init__(data_root, label_data_root=None, image_grep='*', label_grep='*', split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, constant_scale=1, transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=False, reduce_zero_label=False)

Constructor

Parameters:
  • data_root (Path) –

    Path to data root directory

  • label_data_root (Path, default: None ) –

    Path to data root directory with labels. If not specified, will use the same as for images.

  • image_grep (str, default: '*' ) –

    Glob pattern appended to data_root to find input images. Defaults to "*".

  • label_grep (str, default: '*' ) –

    Glob pattern appended to data_root to find ground truth masks. Defaults to "*".

  • split (Path, default: None ) –

    Path to file containing files to be used for this split. The file should contain newline-separated prefixes that appear in the names of the desired files. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace nan values in input images with this value. If None, does no replacement. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace nan values in label with this value. If None, does no replacement. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.
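Several of these parameters are simple per-sample array operations. The sketch below illustrates them with NumPy under stated assumptions: the order of operations is chosen for illustration, num_channels is a hypothetical stand-in for the number of output bands, and the (time*channels, h, w) input is assumed to be time-major (all bands of the first timestep, then all bands of the next). It is not TerraTorch's implementation.

```python
from typing import Optional

import numpy as np


def preprocess(image: np.ndarray, mask: np.ndarray, *, constant_scale: float = 1.0,
               no_data_replace: Optional[float] = None, no_label_replace: Optional[int] = None,
               expand_temporal_dimension: bool = False, num_channels: Optional[int] = None,
               reduce_zero_label: bool = False) -> tuple[np.ndarray, np.ndarray]:
    image = image.astype(np.float32) * constant_scale  # copy, then scale
    if no_data_replace is not None:
        image = np.where(np.isnan(image), no_data_replace, image)
    if no_label_replace is not None:
        mask = np.where(np.isnan(mask), no_label_replace, mask)
    if expand_temporal_dimension:
        if num_channels is None:
            raise ValueError("num_channels is required to expand the temporal dimension")
        time_channels, height, width = image.shape
        # Assumes time-major stacking: all bands of t0, then all bands of t1, ...
        image = image.reshape(time_channels // num_channels, num_channels, height, width)
        image = image.transpose(1, 0, 2, 3)  # (time, C, H, W) -> (C, time, H, W)
    if reduce_zero_label:
        mask = mask - 1  # labels starting at 1 now start at 0
    return image, mask


raw = np.arange(24, dtype=np.float32).reshape(6, 2, 2)  # 2 timesteps x 3 bands
raw[0, 0, 0] = np.nan
labels = np.array([[1.0, 2.0], [np.nan, 3.0]])
image, mask = preprocess(raw, labels, no_data_replace=-1, no_label_replace=1,
                         expand_temporal_dimension=True, num_channels=3, reduce_zero_label=True)
```

If your files interleave bands the other way (all timesteps of band 0 first), the reshape has to swap the time and channel axes accordingly.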

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    a sample returned by __getitem__

  • suptitle (str | None, default: None ) –

    optional string to use as a suptitle

Returns:
  • Figure

    a matplotlib Figure with the rendered sample

Added in version 0.2.
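Before calling imshow, plot selects the RGB bands and rescales them to [0, 1]. A per-channel min-max rescaling of that kind can be sketched as follows; to_display_rgb is a hypothetical helper, and the small epsilon is an added assumption that maps constant bands to 0 instead of dividing by zero.

```python
import numpy as np


def to_display_rgb(image_chw: np.ndarray, rgb_indices: tuple[int, ...] = (0, 1, 2),
                   eps: float = 1e-8) -> np.ndarray:
    """Pick RGB bands from a (C, H, W) array and min-max scale each one to [0, 1]."""
    rgb = np.take(np.asarray(image_chw, dtype=np.float64), list(rgb_indices), axis=0)
    rgb = np.transpose(rgb, (1, 2, 0))  # (H, W, 3), the layout matplotlib expects
    low = rgb.min(axis=(0, 1))
    high = rgb.max(axis=(0, 1))
    # Subtract the per-channel minimum and divide by the per-channel range.
    rgb = (rgb - low) / (high - low + eps)
    return np.clip(rgb, 0.0, 1.0)


bands = np.stack([
    np.full((2, 2), 5.0),
    np.array([[0.0, 10.0], [20.0, 30.0]]),
    np.array([[100.0, 200.0], [300.0, 400.0]]),
])
display = to_display_rgb(bands, rgb_indices=(1, 2, 0))
```

Dividing by the range rather than by the maximum is what guarantees the darkest pixel maps to 0 and the brightest to 1 in every channel.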

GenericNonGeoSegmentationDataset

Bases: GenericPixelWiseDataset

GenericNonGeoSegmentationDataset

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
    """GenericNonGeoSegmentationDataset"""

    def __init__(
        self,
        data_root: Path,
        num_classes: int,
        label_data_root: Path | None = None,
        image_grep: str | None = "*",
        label_grep: str | None = "*",
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        class_names: list[str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            num_classes (int): Number of classes in the dataset
            label_data_root (Path, optional): Path to data root directory with labels.
                If not specified, will use the same as for images.
            image_grep (str, optional): Glob pattern appended to data_root to find input images.
                Defaults to "*".
            label_grep (str, optional): Glob pattern appended to data_root to find ground truth masks.
                Defaults to "*".
            split (Path, optional): Path to file containing files to be used for this split.
                The file should contain newline-separated prefixes that appear in the names of the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            class_names (list[str], optional): Class names. Defaults to None.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If None, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value. If None, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
        """
        super().__init__(
            data_root,
            label_data_root=label_data_root,
            image_grep=image_grep,
            label_grep=label_grep,
            split=split,
            ignore_split_file_extensions=ignore_split_file_extensions,
            allow_substring_split_file=allow_substring_split_file,
            rgb_indices=rgb_indices,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=constant_scale,
            transform=transform,
            no_data_replace=no_data_replace,
            no_label_replace=no_label_replace,
            expand_temporal_dimension=expand_temporal_dimension,
            reduce_zero_label=reduce_zero_label,
        )
        self.num_classes = num_classes
        self.class_names = class_names

    def __getitem__(self, index: int) -> dict[str, Any]:
        item = super().__getitem__(index)
        item["mask"] = item["mask"].long()
        return item

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample

        .. versionadded:: 0.2
        """
        image = sample["image"]
        if len(image.shape) == 5:
            return
        if isinstance(image, Tensor):
            image = image.numpy()
        image = image.take(self.rgb_indices, axis=0)
        image = np.transpose(image, (1, 2, 0))
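        # Per-channel min-max scaling to [0, 1] for display.
        channel_min = image.min(axis=(0, 1))
        channel_range = image.max(axis=(0, 1)) - channel_min
        image = (image - channel_min) / np.maximum(channel_range, 1e-12)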
        image = np.clip(image, 0, 1)

        label_mask = sample["mask"]
        if isinstance(label_mask, Tensor):
            label_mask = label_mask.numpy()

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction_mask = sample["prediction"]
            if isinstance(prediction_mask, Tensor):
                prediction_mask = prediction_mask.numpy()

        return self._plot_sample(
            image,
            label_mask,
            self.num_classes,
            prediction=prediction_mask if showing_predictions else None,
            suptitle=suptitle,
            class_names=self.class_names,
        )

    @staticmethod
    def _plot_sample(image, label, num_classes, prediction=None, suptitle=None, class_names=None):
        num_images = 5 if prediction is not None else 4
        fig, ax = plt.subplots(1, num_images, figsize=(12, 10), layout="compressed")

        # for legend
        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=num_classes - 1)
        ax[1].axis("off")
        ax[1].title.set_text("Image")
        ax[1].imshow(image)

        ax[2].axis("off")
        ax[2].title.set_text("Ground Truth Mask")
        ax[2].imshow(label, cmap="jet", norm=norm)

        ax[3].axis("off")
        ax[3].title.set_text("GT Mask on Image")
        ax[3].imshow(image)
        ax[3].imshow(label, cmap="jet", alpha=0.3, norm=norm)

        if prediction is not None:
            ax[4].title.set_text("Predicted Mask")
            ax[4].imshow(prediction, cmap="jet", norm=norm)

        # One legend entry per class, using the same colormap and normalization as the masks.
        cmap = plt.get_cmap("jet")
        handles = [Rectangle((0, 0), 1, 1, color=cmap(norm(i))) for i in range(num_classes)]
        labels = [class_names[i] if class_names else str(i) for i in range(num_classes)]
        ax[0].legend(handles, labels, loc="center")
        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
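The mask branch of the segmentation dataset (NaN replacement through `no_label_replace`, the optional `reduce_zero_label` shift, and the cast to a long tensor in `GenericNonGeoSegmentationDataset.__getitem__`) can be approximated with NumPy alone. This is a sketch under assumptions, not TerraTorch code: `prepare_mask` is a hypothetical helper that takes an in-memory array instead of a raster opened with `rioxarray`, and it applies the steps in the same order as the dataset code (replacement first, then the shift).

```python
import numpy as np


def prepare_mask(mask, no_label_replace=None, reduce_zero_label=False):
    """Hypothetical helper mirroring the mask handling described above."""
    mask = np.asarray(mask, dtype=np.float32)
    # Nodata pixels arrive as NaN (the raster is opened with masked=True),
    # so they are replaced first, like the fillna call in _load_file.
    if no_label_replace is not None:
        mask = np.where(np.isnan(mask), no_label_replace, mask)
    # reduce_zero_label: shift labels that start at 1 so that they start at 0.
    if reduce_zero_label:
        mask = mask - 1
    # The segmentation dataset finally casts the mask to a long (int64) tensor.
    # Any NaN still present at this point has no meaningful integer value.
    return mask.astype(np.int64)


raw = np.array([[1.0, 2.0], [np.nan, 3.0]])
print(prepare_mask(raw, no_label_replace=0, reduce_zero_label=True).tolist())
# [[0, 1], [-1, 2]]
```

Because the shift runs after the replacement, a nodata value of 0 ends up as -1 in this sketch; `no_label_replace` should be chosen with that order in mind.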
__init__(data_root, num_classes, label_data_root=None, image_grep='*', label_grep='*', split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, class_names=None, constant_scale=1, transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=False, reduce_zero_label=False)

Constructor

Parameters:
  • data_root (Path) –

    Path to data root directory

  • num_classes (int) –

    Number of classes in the dataset

  • label_data_root (Path, default: None ) –

    Path to data root directory with labels. If not specified, will use the same as for images.

  • image_grep (str, default: '*' ) –

    Glob pattern appended to data_root to find input images. Defaults to "*".

  • label_grep (str, default: '*' ) –

    Glob pattern appended to data_root to find ground truth masks. Defaults to "*".

  • split (Path, default: None ) –

    Path to a file listing the files to be used for this split. The file should contain newline-separated prefixes of the desired files. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset.

  • class_names (list[str], default: None ) –

    Class names. Defaults to None.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace NaN values in input images with this value. If None, no replacement is done. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace NaN values in labels with this value. If None, no replacement is done. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.
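The `split`, `ignore_split_file_extensions` and `allow_substring_split_file` options interact, and their effect is easiest to see on a few file names. The function below is a hypothetical, simplified stand-in for the `filter_valid_files` helper that the constructor calls. That helper's implementation is not shown on this page, so the matching rules here (comparing stems when extensions are ignored, and keeping a file if any split entry occurs in its name in substring mode) are assumptions based on the parameter descriptions.

```python
import os


def filter_by_split(files, split_entries, ignore_extensions=True, allow_substring=True):
    """Hypothetical simplification of the split-file filtering described above."""

    def strip_ext(name):
        return os.path.splitext(name)[0]

    # Split files are read line by line; surrounding whitespace is dropped.
    entries = {entry.strip() for entry in split_entries if entry.strip()}
    if ignore_extensions:
        entries = {strip_ext(entry) for entry in entries}

    kept = []
    for path in files:
        name = os.path.basename(path)
        if ignore_extensions:
            name = strip_ext(name)
        if allow_substring:
            # mmsegmentation style: an entry only needs to occur inside the file name.
            match = any(entry in name for entry in entries)
        else:
            # EuroSAT style: the (possibly extension-less) name must match exactly.
            match = name in entries
        if match:
            kept.append(path)
    return kept


files = ["data/AnnualCrop_1.tif", "data/Forest_12.tif", "data/Forest_123.tif"]
split = ["AnnualCrop_1.jpg\n", "Forest_12.jpg\n"]
print(filter_by_split(files, split, allow_substring=False))
# ['data/AnnualCrop_1.tif', 'data/Forest_12.tif']
```

In substring mode the entry `Forest_12` would also select `Forest_123`, so exact matching is the safer choice when a split file lists complete file names.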

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    a sample returned by __getitem__

  • suptitle (str | None, default: None ) –

    optional string to use as a suptitle

Returns:
  • Figure

    a matplotlib Figure with the rendered sample

New in version 0.2.
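Before plotting, `plot` picks the `rgb_indices` channels, moves them to the last axis and rescales the values to [0, 1] for `imshow`. The NumPy sketch below reproduces that preprocessing with per-channel min-max scaling, `(x - min) / (max - min)`; `to_displayable_rgb` is a hypothetical name, and the small `eps` that protects constant channels from division by zero is an assumption of this sketch.

```python
import numpy as np


def to_displayable_rgb(image, rgb_indices=(0, 1, 2), eps=1e-12):
    """Hypothetical helper: (C, H, W) array -> (H, W, 3) float array in [0, 1]."""
    rgb = np.take(image, list(rgb_indices), axis=0)  # pick the RGB bands
    rgb = np.transpose(rgb, (1, 2, 0)).astype(np.float64)  # channels last
    lo = rgb.min(axis=(0, 1))
    hi = rgb.max(axis=(0, 1))
    # Per-channel min-max scaling; eps avoids division by zero for flat channels.
    rgb = (rgb - lo) / np.maximum(hi - lo, eps)
    return np.clip(rgb, 0.0, 1.0)


image = np.arange(4 * 2 * 2, dtype=np.float32).reshape(4, 2, 2) + 100
rgb = to_displayable_rgb(image, rgb_indices=(2, 1, 0))
print(rgb.shape)  # (2, 2, 3)
```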

GenericPixelWiseDataset

Bases: NonGeoDataset, ABC

This is a generic dataset class to be used for instantiating datasets from arguments. Ideally, one would create a dataset class specific to a dataset.
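The constructor discovers images and masks independently (`sorted(glob.glob(os.path.join(root, grep)))`), and `__getitem__` pairs them purely by list position. A small, self-contained illustration with a temporary directory and made-up file names (`*_img.tif`, `*_mask.tif`) shows why both patterns should match the same tiles and sort in the same order:

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # Made-up layout: one empty image file and one empty mask file per tile.
    for tile in ("tile_b", "tile_a", "tile_c"):
        for suffix in ("_img.tif", "_mask.tif"):
            open(os.path.join(root, tile + suffix), "w").close()

    # Same discovery logic as the constructor: glob, then sort.
    image_files = sorted(glob.glob(os.path.join(root, "*_img.tif")))
    mask_files = sorted(glob.glob(os.path.join(root, "*_mask.tif")))

    # __getitem__ pairs the two lists by index, so they must line up one-to-one.
    assert len(image_files) == len(mask_files)
    pairs = [(os.path.basename(i), os.path.basename(m)) for i, m in zip(image_files, mask_files)]

print(pairs[0])  # ('tile_a_img.tif', 'tile_a_mask.tif')
```

If the two patterns matched names that sort differently, images would silently be paired with the wrong masks; the `assert` above only guards the count.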

Source code in terratorch/datasets/generic_pixel_wise_dataset.py
class GenericPixelWiseDataset(NonGeoDataset, ABC):
    """
    This is a generic dataset class to be used for instantiating datasets from arguments.
    Ideally, one would create a dataset class specific to a dataset.
    """

    def __init__(
        self,
        data_root: Path,
        label_data_root: Path | None = None,
        image_grep: str | None = "*",
        label_grep: str | None = "*",
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            label_data_root (Path, optional): Path to data root directory with labels.
                If not specified, will use the same as for images.
            image_grep (str, optional): Glob pattern appended to data_root to find input images.
                Defaults to "*".
            label_grep (str, optional): Glob pattern appended to data_root to find ground truth masks.
                Defaults to "*".
            split (Path, optional): Path to a file listing the files to be used for this split.
                The file should contain newline-separated prefixes of the desired files.
                Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. EuroSAT). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace NaN values in input images with this value. If None, no replacement is done. Defaults to None.
            no_label_replace (int | None): Replace NaN values in labels with this value. If None, no replacement is done. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
        """
        super().__init__()

        self.split_file = split

        label_data_root = label_data_root if label_data_root is not None else data_root
        self.image_files = sorted(glob.glob(os.path.join(data_root, image_grep)))
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(label_data_root, label_grep)))
        self.reduce_zero_label = reduce_zero_label
        self.expand_temporal_dimension = expand_temporal_dimension

        if self.expand_temporal_dimension and output_bands is None:
            msg = "Please provide output_bands when expand_temporal_dimension is True"
            raise Exception(msg)
        if self.split_file is not None:
            with open(self.split_file) as f:
                split = f.readlines()
            valid_files = {rf"{substring.strip()}" for substring in split}
            self.image_files = filter_valid_files(
                self.image_files,
                valid_files=valid_files,
                ignore_extensions=ignore_split_file_extensions,
                allow_substring=allow_substring_split_file,
            )
            self.segmentation_mask_files = filter_valid_files(
                self.segmentation_mask_files,
                valid_files=valid_files,
                ignore_extensions=ignore_split_file_extensions,
                allow_substring=allow_substring_split_file,
            )
        self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

        self.dataset_bands = generate_bands_intervals(dataset_bands)
        self.output_bands = generate_bands_intervals(output_bands)

        if self.output_bands and not self.dataset_bands:
            msg = "If output bands provided, dataset_bands must also be provided"
            raise Exception(msg)

        # There is a special condition if the bands are defined as simple strings.
        if self.output_bands:
            if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
                msg = "Output bands must be a subset of dataset bands"
                raise Exception(msg)

            self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

        else:
            self.filter_indices = None

        # If no transform is given, only convert the sample to a torch tensor
        self.transform = transform if transform else default_transform
        # self.transform = transform if transform else ToTensorV2()

        import warnings

        import rasterio

        warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace).to_numpy()
        if self.expand_temporal_dimension:
            image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands))
        # to channels last
        image = np.moveaxis(image, 0, -1)

        if self.filter_indices:
            image = image[..., self.filter_indices]
        output = {
            "image": image.astype(np.float32) * self.constant_scale,
            "mask": self._load_file(
                self.segmentation_mask_files[index], nan_replace=self.no_label_replace
            ).to_numpy()[0],
        }

        if self.reduce_zero_label:
            output["mask"] -= 1
        if self.transform:
            output = self.transform(**output)
        output["filename"] = self.image_files[index]

        return output

    def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data
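The image branch of `__getitem__` can be traced on a dummy array. This sketch assumes a raster that stacks two time steps of three channels along the band axis in the order encoded by the einops pattern `(channels time)`, i.e. band `k` holds channel `k // time` at time step `k % time`. It substitutes `numpy.reshape` for `einops.rearrange` (for this pattern both give the same result), and the `constant_scale` value is arbitrary.

```python
import numpy as np

channels, time, h, w = 3, 2, 4, 4
# Dummy raster as read from disk: time steps folded into the band axis.
raw = np.arange(channels * time * h * w, dtype=np.float32).reshape(channels * time, h, w)

# expand_temporal_dimension=True, i.e. "(channels time) h w -> channels time h w".
# For this pattern a C-order reshape matches einops.rearrange.
image = raw.reshape(channels, time, h, w)

# "to channels last": move the channel axis to the end, as Albumentations expects.
image = np.moveaxis(image, 0, -1)  # (time, h, w, channels)

# Cast to float32 and apply constant_scale (value chosen for illustration).
constant_scale = 0.5
image = image.astype(np.float32) * constant_scale

print(image.shape)  # (2, 4, 4, 3)
```

Without `expand_temporal_dimension`, the same `np.moveaxis` call simply turns a `(C, H, W)` array into `(H, W, C)`.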
__init__(data_root, label_data_root=None, image_grep='*', label_grep='*', split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, constant_scale=1, transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=False, reduce_zero_label=False)

Constructor

Parameters:
  • data_root (Path) –

    Path to data root directory

  • label_data_root (Path, default: None ) –

    Path to data root directory with labels. If not specified, will use the same as for images.

  • image_grep (str, default: '*' ) –

    Glob pattern appended to data_root to find input images. Defaults to "*".

  • label_grep (str, default: '*' ) –

    Glob pattern appended to data_root to find ground truth masks. Defaults to "*".

  • split (Path, default: None ) –

    Path to a file listing the files to be used for this split. The file should contain newline-separated prefixes of the desired files. Files will be searched using glob with the form Path(data_root).glob(prefix + [image or label grep])

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.

  • output_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands that should be output by the dataset as named by dataset_bands.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace NaN values in input images with this value. If None, no replacement is done. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace NaN values in labels with this value. If None, no replacement is done. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.
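`dataset_bands` names every channel stored in the file, and `output_bands` selects a subset of them by name; the constructor converts the selection into positional `filter_indices` with `dataset_bands.index(band)`, which `__getitem__` then applies to the last axis. A pure-Python sketch of that logic follows. It uses plain strings (chosen to resemble HLS band names) in place of `HLSBands` members, skips the interval expansion done by `generate_bands_intervals`, and raises `ValueError` for illustration.

```python
def band_filter_indices(dataset_bands, output_bands):
    """Hypothetical mirror of the band-selection checks in the constructor."""
    if not output_bands:
        return None  # no filtering: every channel is kept
    if not dataset_bands:
        raise ValueError("If output bands provided, dataset_bands must also be provided")
    # Same subset test as the constructor: every requested band must exist.
    if len(set(output_bands) & set(dataset_bands)) != len(output_bands):
        raise ValueError("Output bands must be a subset of dataset bands")
    # Positions in the file, in the order requested by output_bands.
    return [dataset_bands.index(band) for band in output_bands]


dataset_bands = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]
print(band_filter_indices(dataset_bands, ["RED", "GREEN", "BLUE"]))  # [2, 1, 0]
```

Because the indices follow the order of `output_bands`, the selection can also reorder channels, for example to put red first.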


terratorch.datasets.generic_scalar_label_dataset

Module containing generic dataset classes

GenericNonGeoClassificationDataset

Bases: GenericScalarLabelDataset

GenericNonGeoClassificationDataset

Source code in terratorch/datasets/generic_scalar_label_dataset.py
class GenericNonGeoClassificationDataset(GenericScalarLabelDataset):
    """GenericNonGeoClassificationDataset"""

    def __init__(
        self,
        data_root: Path,
        num_classes: int,
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,  # noqa: FBT001, FBT002
        allow_substring_split_file: bool = True,  # noqa: FBT001, FBT002
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int] | None = None,
        output_bands: list[HLSBands | int] | None = None,
        class_names: list[str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float = 0,
        expand_temporal_dimension: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        """A generic Non-Geo dataset for classification.

        Args:
            data_root (Path): Path to data root directory
            num_classes (int): Number of classes in the dataset
            split (Path, optional): Path to a file listing the files to be used for this split.
                The file should contain newline-separated prefixes of the desired file names.
                All files under data_root are collected recursively, and only those matching an
                entry of the split file are kept.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split file contains substrings
                that must be present in file names for them to be included (as in mmsegmentation), or
                exact matches (e.g. EuroSAT). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            class_names (list[str], optional): Class names. Defaults to None.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
        """
        super().__init__(
            data_root,
            split=split,
            ignore_split_file_extensions=ignore_split_file_extensions,
            allow_substring_split_file=allow_substring_split_file,
            rgb_indices=rgb_indices,
            dataset_bands=dataset_bands,
            output_bands=output_bands,
            constant_scale=constant_scale,
            transform=transform,
            no_data_replace=no_data_replace,
            expand_temporal_dimension=expand_temporal_dimension,
        )
        self.num_classes = num_classes
        self.class_names = class_names

    def __getitem__(self, index: int) -> dict[str, Any]:
        item = super().__getitem__(index)
        item["label"] = torch.tensor(item["label"]).long()
        return item

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        pass
__init__(data_root, num_classes, split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, class_names=None, constant_scale=1, transform=None, no_data_replace=0, expand_temporal_dimension=False)

A generic Non-Geo dataset for classification.

Parameters:
  • data_root (Path) –

    Path to data root directory

  • num_classes (int) –

    Number of classes in the dataset

  • split (Path, default: None ) –

    Path to a file listing the files to be used for this split. The file should contain newline-separated prefixes of the desired file names. All files under data_root are collected recursively, and only those matching an entry of the split file are kept.

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split file contains substrings that must be present in file names for them to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset.

  • class_names (list[str], default: None ) –

    Class names. Defaults to None.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float, default: 0 ) –

    Replace nan values in input images with this value. Defaults to 0.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.
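The two split-file flags interact as follows. The sketch below is a hypothetical re-implementation for illustration only (it is not TerraTorch's `filter_valid_files`); it assumes entries are compared against each file's base name and that, when extensions are ignored, they are stripped from both the file name and the split entries.

```python
import os


def keep_file(path, split_entries, ignore_extensions=True, allow_substring=True):
    """Hypothetical illustration of the split-file matching rules.

    Assumes entries are compared with the file's base name and that, when
    ignore_extensions is True, extensions are stripped on both sides.
    """
    name = os.path.basename(path)
    entries = set(split_entries)
    if ignore_extensions:
        name = os.path.splitext(name)[0]
        entries = {os.path.splitext(entry)[0] for entry in entries}
    if allow_substring:
        # mmsegmentation style: an entry only has to occur inside the file name.
        return any(entry in name for entry in entries)
    # EuroSAT style: the file name must equal one of the entries.
    return name in entries


split = ["River_1.jpg", "Forest_12.jpg"]
print(keep_file("/data/River/River_1.tif", split))                          # True
print(keep_file("/data/River/River_1.tif", split, ignore_extensions=False,
                allow_substring=False))                                      # False
print(keep_file("/data/River/River_12.tif", split))                         # True (substring over-match)
print(keep_file("/data/River/River_12.tif", split, allow_substring=False))  # False
```

The third call shows a trade-off of substring matching: the entry `River_1` also selects `River_12.tif`, which is one reason exact matching can be the safer choice when split files list complete file names.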

GenericScalarLabelDataset

Bases: NonGeoDataset, ImageFolder, ABC

This is a generic dataset class to be used for instantiating datasets from arguments. Ideally, one would create a dataset class specific to a dataset.
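Because this class inherits from `ImageFolder`, the scalar label of each sample comes from the directory layout rather than from a separate label file. The sketch below assumes this is torchvision's `ImageFolder`, which treats each immediate sub-directory of the root as a class, numbers the classes in sorted order starting at 0, and labels every file below a class directory with that index. `scalar_labels` is a hypothetical helper that re-implements only that rule with the standard library; it ignores the `is_valid_file` filtering the dataset also applies.

```python
import os
import tempfile


def scalar_labels(root):
    """Sketch of ImageFolder-style labelling (not torchvision's actual code).

    Each immediate sub-directory of root is a class; classes are sorted by
    name and numbered from 0, and every file below a class directory gets
    that class index as its scalar label.
    """
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    samples = []
    for name in classes:
        for dirpath, _, filenames in sorted(os.walk(os.path.join(root, name))):
            for filename in sorted(filenames):
                samples.append((os.path.join(dirpath, filename), class_to_idx[name]))
    return class_to_idx, samples


with tempfile.TemporaryDirectory() as root:
    # Expected layout: <root>/<class_name>/<image file>
    for cls, fname in [("River", "r1.tif"), ("Forest", "f1.tif"), ("Forest", "f2.tif")]:
        os.makedirs(os.path.join(root, cls), exist_ok=True)
        open(os.path.join(root, cls, fname), "w").close()
    class_to_idx, samples = scalar_labels(root)
    print(class_to_idx)                       # {'Forest': 0, 'River': 1}
    print([label for _, label in samples])    # [0, 0, 1]
```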

Source code in terratorch/datasets/generic_scalar_label_dataset.py
class GenericScalarLabelDataset(NonGeoDataset, ImageFolder, ABC):
    """
    This is a generic dataset class to be used for instantiating datasets from arguments.
    Ideally, one would create a dataset class specific to a dataset.
    """

    def __init__(
        self,
        data_root: Path,
        split: Path | None = None,
        ignore_split_file_extensions: bool = True,  # noqa: FBT001, FBT002
        allow_substring_split_file: bool = True,  # noqa: FBT001, FBT002
        rgb_indices: list[int] | None = None,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        constant_scale: float = 1,
        transform: A.Compose | None = None,
        no_data_replace: float = 0,
        expand_temporal_dimension: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        """Constructor

        Args:
            data_root (Path): Path to data root directory
            split (Path, optional): Path to a file listing the files to be used for this split.
                The file should contain newline-separated prefixes of the desired file names.
                All files under data_root are collected recursively, and only those matching an
                entry of the split file are kept.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split file contains substrings
                that must be present in file names for them to be included (as in mmsegmentation), or
                exact matches (e.g. EuroSAT). Defaults to True.
            rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This
                parameter gives identifiers to input channels (bands) so that they can then be referred to by
                output_bands. Can use the HLSBands enum, ints, inclusive int ranges given as (start, end) tuples,
                or strings. Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the
                dataset as named by dataset_bands.
            constant_scale (float): Factor to multiply image values by. Defaults to 1.
            transform (Albumentations.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
        """
        self.split_file = split

        self.image_files = sorted(glob.glob(os.path.join(data_root, "**"), recursive=True))
        self.image_files = [f for f in self.image_files if not os.path.isdir(f)]
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.expand_temporal_dimension = expand_temporal_dimension
        if self.expand_temporal_dimension and output_bands is None:
            msg = "Please provide output_bands when expand_temporal_dimension is True"
            raise Exception(msg)
        if self.split_file is not None:
            with open(self.split_file) as f:
                split = f.readlines()
            valid_files = {rf"{substring.strip()}" for substring in split}
            self.image_files = filter_valid_files(
                self.image_files,
                valid_files=valid_files,
                ignore_extensions=ignore_split_file_extensions,
                allow_substring=allow_substring_split_file,
            )

            def is_valid_file(x):
                return x in self.image_files

        else:

            def is_valid_file(x):
                return True

        super().__init__(
            root=data_root, transform=None, target_transform=None, loader=rasterio_loader, is_valid_file=is_valid_file
        )

        self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

        self.dataset_bands = generate_bands_intervals(dataset_bands)
        self.output_bands = generate_bands_intervals(output_bands)

        if self.output_bands and not self.dataset_bands:
            msg = "If output bands provided, dataset_bands must also be provided"
            raise Exception(msg)

        # There is a special condition if the bands are defined as simple strings.
        if self.output_bands:
            if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
                msg = "Output bands must be a subset of dataset bands"
                raise Exception(msg)

            self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

        else:
            self.filter_indices = None
        # If no transform is given, only apply the conversion to a torch tensor
        self.transforms = transform if transform else default_transform
        # self.transform = transform if transform else ToTensorV2()

        import warnings

        import rasterio
        warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image, label = ImageFolder.__getitem__(self, index)
        if self.expand_temporal_dimension:
            image = rearrange(image, "h w (channels time) -> time h w channels", channels=len(self.output_bands))
        if self.filter_indices:
            image = image[..., self.filter_indices]

        image = image.astype(np.float32) * self.constant_scale

        if self.transforms:
            image = self.transforms(image=image)["image"]  # albumentations returns dict

        output = {
            "image": image,
            "label": label,  # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
            "filename": self.image_files[index]
        }

        return output

    def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | tuple[int]] | None = None):
        if bands_intervals is None:
            return None
        bands = []
        for element in bands_intervals:
            # if its an interval
            if isinstance(element, tuple):
                if len(element) != 2:  # noqa: PLR2004
                    msg = (
                        "When defining an interval, a tuple of two integers should be passed, "
                        "defining start and end indices inclusive"
                    )
                    raise Exception(msg)
                expanded_element = list(range(element[0], element[1] + 1))
                bands.extend(expanded_element)
            else:
                bands.append(element)
        return bands

    def _load_file(self, path) -> xr.DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        data = data.fillna(self.no_data_replace)
        return data
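The interval expansion in `_generate_bands_intervals` (and, presumably, in the module-level `generate_bands_intervals` that the constructor actually calls) turns each `(start, end)` tuple into every integer from `start` to `end`, inclusive on both ends unlike Python's `range`. The following stand-alone sketch mirrors that logic; `expand_band_intervals` is a hypothetical name, and it raises `ValueError` rather than a bare `Exception`.

```python
def expand_band_intervals(bands):
    """Sketch of the band-interval expansion described above.

    Tuples (start, end) expand to every integer in the inclusive range;
    any other element (int, str, enum member) is kept unchanged.
    """
    if bands is None:
        return None
    expanded = []
    for element in bands:
        if isinstance(element, tuple):
            if len(element) != 2:
                raise ValueError("Intervals must be (start, end) tuples, inclusive on both ends")
            start, end = element
            expanded.extend(range(start, end + 1))  # +1 makes the end inclusive
        else:
            expanded.append(element)
    return expanded


print(expand_band_intervals([(0, 3), "NDVI", 7]))  # [0, 1, 2, 3, 'NDVI', 7]
```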
__init__(data_root, split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, rgb_indices=None, dataset_bands=None, output_bands=None, constant_scale=1, transform=None, no_data_replace=0, expand_temporal_dimension=False)

Constructor

Parameters:
  • data_root (Path) –

    Path to data root directory

  • split (Path, default: None ) –

    Path to a file listing the files to be used for this split. The file should contain newline-separated prefixes of the desired file names. All files under data_root are collected recursively, and only those matching an entry of the split file are kept.

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split file contains substrings that must be present in file names for them to be included (as in mmsegmentation), or exact matches (e.g. EuroSAT). Defaults to True.

  • rgb_indices (list[int], default: None ) –

    Indices of RGB channels. Defaults to [0, 1, 2].

  • dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands present in the dataset. This parameter gives identifiers to input channels (bands) so that they can then be referred to by output_bands. Can use the HLSBands enum, ints, inclusive int ranges given as (start, end) tuples, or strings. Defaults to None.

  • output_bands (list[HLSBands | int | tuple[int, int] | str] | None, default: None ) –

    Bands that should be output by the dataset as named by dataset_bands.

  • constant_scale (float, default: 1 ) –

    Factor to multiply image values by. Defaults to 1.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float, default: 0 ) –

    Replace nan values in input images with this value. Defaults to 0.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.
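When `expand_temporal_dimension` is set, `__getitem__` calls einops with the pattern `"h w (channels time) -> time h w channels"`. In an einops grouped axis the first name is the outer (slower-varying) factor, so stacked channel `k` belongs to band `k // T` at timestep `k % T`, and the result is laid out as (time, h, w, channels). The sketch below reproduces that index mapping for a single pixel in pure Python; `unstack_pixel` is a hypothetical helper, not part of TerraTorch.

```python
def unstack_pixel(values, num_channels):
    """Split one pixel's stacked values as "(channels time) -> time channels".

    values has length num_channels * num_time; entry k belongs to
    channel k // num_time at timestep k % num_time (channels is the outer factor).
    """
    if len(values) % num_channels != 0:
        raise ValueError("stacked length must be a multiple of num_channels")
    num_time = len(values) // num_channels
    return [
        [values[c * num_time + t] for c in range(num_channels)]
        for t in range(num_time)
    ]


# Two bands (B, N) observed at three timesteps, stored channel-major.
pixel = ["B_t0", "B_t1", "B_t2", "N_t0", "N_t1", "N_t2"]
print(unstack_pixel(pixel, num_channels=2))
# [['B_t0', 'N_t0'], ['B_t1', 'N_t1'], ['B_t2', 'N_t2']]
```

For a whole NumPy image of shape (h, w, C*T), the same rearrangement can be written as `image.reshape(h, w, C, T).transpose(3, 0, 1, 2)`.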


Generic Data Modules

terratorch.datamodules.generic_pixel_wise_data_module

This module contains generic data modules for instantiation at runtime.

GenericNonGeoPixelwiseRegressionDataModule

Bases: NonGeoDataModule

This is a generic datamodule class for instantiating data modules at runtime. It composes several GenericNonGeoPixelwiseRegressionDataset instances, one per data split.
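The data module normalizes images itself with the `means` and `stds` it receives (`self.aug = Normalize(means, stds)` in the constructor), which is why the per-split transforms should not include normalization. The sketch below assumes the usual per-channel definition, (x - mean[c]) / std[c], and works on a channels-first nested list; `normalize_channels` is a hypothetical helper, not the `Normalize` class used by the module.

```python
def normalize_channels(image, means, stds):
    """Per-channel standardisation: (x - mean[c]) / std[c].

    image is a channels-first nested list: image[c][row][col].
    One mean and one std are required per channel.
    """
    if not (len(image) == len(means) == len(stds)):
        raise ValueError("need one mean and one std per channel")
    return [
        [[(value - mean) / std for value in row] for row in channel]
        for channel, mean, std in zip(image, means, stds)
    ]


image = [
    [[10.0, 20.0], [30.0, 40.0]],  # channel 0
    [[0.5, 1.5], [2.5, 3.5]],      # channel 1
]
print(normalize_channels(image, means=[25.0, 2.0], stds=[5.0, 0.5]))
# [[[-3.0, -1.0], [1.0, 3.0]], [[-3.0, -1.0], [1.0, 3.0]]]
```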

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
class GenericNonGeoPixelwiseRegressionDataModule(NonGeoDataModule):
    """This is a generic datamodule class for instantiating data modules at runtime.
    Composes several
    [GenericNonGeoPixelwiseRegressionDataset][terratorch.datasets.GenericNonGeoPixelwiseRegressionDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        means: list[float] | str,
        stds: list[float] | str,
        predict_data_root: Path | None = None,
        img_grep: str | None = "*",
        label_grep: str | None = "*",
        train_label_data_root: Path | None = None,
        val_label_data_root: Path | None = None,
        test_label_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        drop_last: bool = True,
        pin_memory: bool = False,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): Number of samples per batch.
            num_workers (int): Number of worker processes used for data loading.
            train_data_root (Path): Root directory of the training images.
            val_data_root (Path): Root directory of the validation images.
            test_data_root (Path): Root directory of the test images.
            means (list[float] | str): Per-band means used for normalization, or a path to a file
                containing them.
            stds (list[float] | str): Per-band standard deviations used for normalization, or a path
                to a file containing them.
            predict_data_root (Path | None, optional): Root directory of the images to run prediction on.
                Defaults to None.
            img_grep (str | None, optional): Glob pattern used to match image files. Defaults to "*".
            label_grep (str | None, optional): Glob pattern used to match label files. Defaults to "*".
            train_label_data_root (Path | None, optional): Root directory of the training labels.
                Defaults to None.
            val_label_data_root (Path | None, optional): Root directory of the validation labels.
                Defaults to None.
            test_label_data_root (Path | None, optional): Root directory of the test labels.
                Defaults to None.
            train_split (Path | None, optional): Split file selecting the training files. Defaults to None.
            val_split (Path | None, optional): Split file selecting the validation files. Defaults to None.
            test_split (Path | None, optional): Split file selecting the test files. Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split file contains substrings
                that must be present in file names for them to be included (as in mmsegmentation), or
                exact matches (e.g. EuroSAT). Defaults to True.
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the
                dataset. Defaults to None.
            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be
                output by the dataset. Naming must match that of dataset_bands. Defaults to None.
            predict_dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Overwrites
                dataset_bands with this value at predict time. Defaults to None, which does not overwrite.
            predict_output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Overwrites
                output_bands with this value at predict time. Defaults to None, which does not overwrite.
            constant_scale (float, optional): Factor to multiply image values by. Defaults to 1.
            rgb_indices (list[int] | None, optional): Indices of RGB channels. Defaults to None.
            train_transform (Albumentations.Compose | list[Albumentations.BasicTransform] | None):
                Albumentations transform to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | list[Albumentations.BasicTransform] | None):
                Albumentations transform to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | list[Albumentations.BasicTransform] | None):
                Albumentations transform to be applied to the test dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value. If None,
                no replacement is done. Defaults to None.
            no_label_replace (int | None): Replace nan values in labels with this value. If None, no
                replacement is done. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
            drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
            pin_memory (bool): If ``True``, the data loader will copy Tensors
                into device/CUDA pinned memory before returning them. Defaults to False.

        """
        super().__init__(GenericNonGeoPixelwiseRegressionDataset, batch_size, num_workers, **kwargs)
        self.img_grep = img_grep
        self.label_grep = label_grep
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.drop_last = drop_last
        self.pin_memory = pin_memory
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label

        self.train_label_data_root = train_label_data_root
        self.val_label_data_root = val_label_data_root
        self.test_label_data_root = test_label_data_root

        self.constant_scale = constant_scale

        self.dataset_bands = dataset_bands
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices

        # self.aug = AugmentationSequential(
        #     K.Normalize(means, stds),
        #     data_keys=["image"],
        # )
        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                self.train_root,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.train_label_data_root,
                split=self.train_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.train_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                self.val_root,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.val_label_data_root,
                split=self.val_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.val_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                self.test_root,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.test_label_data_root,
                split=self.test_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )

        if stage in ["predict"] and self.predict_root:
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                image_grep=self.img_grep,
                dataset_bands=self.predict_dataset_bands,
                output_bands=self.predict_output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        batch_size = check_dataset_stackability(dataset, batch_size)

        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
            pin_memory=self.pin_memory,
        )
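The `setup` and `_dataloader_factory` methods above decide which datasets exist for each Lightning stage and how each loader is configured. The stdlib-only sketch below restates that branching so it can be checked at a glance. `datasets_for_stage` and `loader_flags` are hypothetical helpers written for this page, not part of TerraTorch.

```python
def datasets_for_stage(stage: str, has_predict_root: bool) -> list[str]:
    """Names of the dataset attributes that setup() populates for a Lightning stage.

    Hypothetical helper: it restates the branching of setup(), it is not TerraTorch API.
    """
    built: list[str] = []
    if stage == "fit":
        built.append("train_dataset")
    if stage in ("fit", "validate"):
        built.append("val_dataset")
    if stage == "test":
        built.append("test_dataset")
    # The predict dataset is only built when predict_data_root was given.
    if stage == "predict" and has_predict_root:
        built.append("predict_dataset")
    return built


def loader_flags(split: str, drop_last: bool) -> dict[str, bool]:
    """Shuffle / drop_last settings that _dataloader_factory() passes to DataLoader."""
    return {
        "shuffle": split == "train",  # only the training loader shuffles
        "drop_last": split == "train" and drop_last,  # drop_last never affects val/test/predict
    }


print(datasets_for_stage("fit", has_predict_root=False))      # ['train_dataset', 'val_dataset']
print(datasets_for_stage("predict", has_predict_root=False))  # []
print(loader_flags("val", drop_last=True))                    # {'shuffle': False, 'drop_last': False}
```

Since no predict dataset is built when `predict_data_root` is None, asking for a predict loader in that case should end in the `MisconfigurationException` described in the `_dataloader_factory` docstring.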
__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, means, stds, predict_data_root=None, img_grep='*', label_grep='*', train_label_data_root=None, val_label_data_root=None, test_label_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, output_bands=None, predict_dataset_bands=None, predict_output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, reduce_zero_label=False, no_data_replace=None, no_label_replace=None, drop_last=True, pin_memory=False, **kwargs)

Constructor

Parameters:
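  • batch_size (int) –

    Batch size used by the data loaders.

  • num_workers (int) –

    Number of worker processes used by the data loaders.

  • train_data_root (Path) –

    Root directory of the training data.

  • val_data_root (Path) –

    Root directory of the validation data.

  • test_data_root (Path) –

    Root directory of the test data.

  • predict_data_root (Path, default: None ) –

    Root directory of the data to run prediction on. If None, no predict dataset is created.

  • img_grep (str, default: '*' ) –

    Glob pattern used to select image files. Defaults to '*'.

  • label_grep (str, default: '*' ) –

    Glob pattern used to select label files. Defaults to '*'.

  • means (list[float]) –

    Per-band means used to normalize the input images. A str value is resolved with load_from_file_or_attribute.

  • stds (list[float]) –

    Per-band standard deviations used to normalize the input images. A str value is resolved with load_from_file_or_attribute.

  • train_label_data_root (Path | None, default: None ) –

    Directory containing the training labels. Defaults to None.

  • val_label_data_root (Path | None, default: None ) –

    Directory containing the validation labels. Defaults to None.

  • test_label_data_root (Path | None, default: None ) –

    Directory containing the test labels. Defaults to None.

  • train_split (Path | None, default: None ) –

    Split file listing the samples of the training split. Defaults to None.

  • val_split (Path | None, default: None ) –

    Split file listing the samples of the validation split. Defaults to None.

  • test_split (Path | None, default: None ) –

    Split file listing the samples of the test split. Defaults to None.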

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard extensions when using the split file to determine which files to include in the dataset. E.g. necessary for EuroSAT, since the split files specify ".jpg" but the files are actually ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether the split files contain substrings that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True.

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset. Defaults to None.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset. Naming must match that of dataset_bands. Defaults to None.

  • predict_dataset_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites dataset_bands with this value at predict time. Defaults to None, which does not overwrite.

  • predict_output_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites output_bands with this value at predict time. Defaults to None, which does not overwrite.

  • constant_scale (float, default: 1 ) –

    description. Defaults to 1.

  • rgb_indices (list[int] | None, default: None ) –

    description. Defaults to None.

  • train_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • val_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • test_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the test dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace NaN values in input images with this value. If None, no replacement is done. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace NaN values in labels with this value. If None, no replacement is done. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

  • drop_last (bool, default: True ) –

    Drop the last batch if it is not complete. Defaults to True.

  • pin_memory (bool, default: False ) –

    If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. Defaults to False.
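The two split-file flags, `ignore_split_file_extensions` and `allow_substring_split_file`, interact in ways that are easy to get wrong. The sketch below models them from their parameter descriptions only: `select_files` is a hypothetical helper, not TerraTorch's implementation, and it assumes "disregarding extensions" means comparing file stems.

```python
from pathlib import PurePath


def select_files(
    file_names: list[str],
    split_entries: list[str],
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
) -> list[str]:
    """Keep the files named by a split file (hypothetical helper, not TerraTorch's code)."""

    def normalize(name: str) -> str:
        # Drop the extension so "a.jpg" in the split file can match "a.tif" on disk.
        return PurePath(name).stem if ignore_split_file_extensions else name

    entries = [normalize(e.strip()) for e in split_entries if e.strip()]
    selected = []
    for file_name in file_names:
        key = normalize(file_name)
        if allow_substring_split_file:
            # mmsegmentation-style: an entry only needs to occur somewhere in the name.
            keep = any(entry in key for entry in entries)
        else:
            # Exact matching, e.g. for split files that list full sample names.
            keep = key in entries
        if keep:
            selected.append(file_name)
    return selected


files = ["AnnualCrop_1.tif", "AnnualCrop_12.tif", "Forest_3.tif"]
split = ["AnnualCrop_1.jpg", "Forest_3.jpg"]
print(select_files(files, split, allow_substring_split_file=False))  # ['AnnualCrop_1.tif', 'Forest_3.tif']
print(select_files(files, split))  # also keeps AnnualCrop_12.tif, since "AnnualCrop_1" is a substring of it
```

The second call shows why substring matching (the default) can silently pull in extra samples when sample names share prefixes; exact matching avoids that at the cost of requiring full names in the split file.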

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    means: list[float] | str,
    stds: list[float] | str,
    predict_data_root: Path | None = None,
    img_grep: str | None = "*",
    label_grep: str | None = "*",
    train_label_data_root: Path | None = None,
    val_label_data_root: Path | None = None,
    test_label_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    drop_last: bool = True,
    pin_memory: bool = False,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
        batch_size (int): _description_
        num_workers (int): _description_
        train_data_root (Path): _description_
        val_data_root (Path): _description_
        test_data_root (Path): _description_
        predict_data_root (Path): _description_
        img_grep (str): _description_
        label_grep (str): _description_
        means (list[float]): _description_
        stds (list[float]): _description_
        train_label_data_root (Path | None, optional): _description_. Defaults to None.
        val_label_data_root (Path | None, optional): _description_. Defaults to None.
        test_label_data_root (Path | None, optional): _description_. Defaults to None.
        train_split (Path | None, optional): _description_. Defaults to None.
        val_split (Path | None, optional): _description_. Defaults to None.
        test_split (Path | None, optional): _description_. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            Naming must match that of dataset_bands. Defaults to None.
        predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
            with this value at predict time.
            Defaults to None, which does not overwrite.
        predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
            with this value at predict time. Defaults to None, which does not overwrite.
        constant_scale (float, optional): _description_. Defaults to 1.
        rgb_indices (list[int] | None, optional): _description_. Defaults to None.
        train_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the test dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace NaN values in input images with this value. If None, no replacement is done. Defaults to None.
        no_label_replace (int | None): Replace NaN values in label with this value. If None, no replacement is done. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
        drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        pin_memory (bool): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them. Defaults to False.

    """
    super().__init__(GenericNonGeoPixelwiseRegressionDataset, batch_size, num_workers, **kwargs)
    self.img_grep = img_grep
    self.label_grep = label_grep
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.drop_last = drop_last
    self.pin_memory = pin_memory
    self.expand_temporal_dimension = expand_temporal_dimension
    self.reduce_zero_label = reduce_zero_label

    self.train_label_data_root = train_label_data_root
    self.val_label_data_root = val_label_data_root
    self.test_label_data_root = test_label_data_root

    self.constant_scale = constant_scale

    self.dataset_bands = dataset_bands
    self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
    self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices

    # self.aug = AugmentationSequential(
    #     K.Normalize(means, stds),
    #     data_keys=["image"],
    # )
    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
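The constructor resolves the predict-time bands with a plain truthiness test (`x if x else fallback`). The sketch below isolates that rule; `resolve_predict_bands` is a hypothetical helper and the band names are illustrative strings only. One consequence worth knowing: an empty list behaves exactly like None and falls back to the training bands.

```python
from typing import Optional, Sequence

Bands = Optional[Sequence[str]]


def resolve_predict_bands(
    dataset_bands: Bands,
    output_bands: Bands,
    predict_dataset_bands: Bands = None,
    predict_output_bands: Bands = None,
) -> tuple[Bands, Bands]:
    """Predict-time bands, using the same `x if x else fallback` test as the constructor."""
    # Truthiness check: None *and* an empty list both fall back to the training bands.
    resolved_dataset = predict_dataset_bands if predict_dataset_bands else dataset_bands
    resolved_output = predict_output_bands if predict_output_bands else output_bands
    return resolved_dataset, resolved_output


train_bands = ["BLUE", "GREEN", "RED", "NIR"]  # illustrative band names only
rgb = ["RED", "GREEN", "BLUE"]
print(resolve_predict_bands(train_bands, rgb))  # (['BLUE', 'GREEN', 'RED', 'NIR'], ['RED', 'GREEN', 'BLUE'])
print(resolve_predict_bands(train_bands, rgb, predict_dataset_bands=rgb))  # (['RED', 'GREEN', 'BLUE'], ['RED', 'GREEN', 'BLUE'])
```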
GenericNonGeoSegmentationDataModule

Bases: NonGeoDataModule

This is a generic data module class for instantiating data modules at runtime. It composes several GenericNonGeoSegmentationDatasets.

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
class GenericNonGeoSegmentationDataModule(NonGeoDataModule):
    """
    This is a generic datamodule class for instantiating data modules at runtime.
    Composes several [GenericNonGeoSegmentationDatasets][terratorch.datasets.GenericNonGeoSegmentationDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        img_grep: str,
        label_grep: str,
        means: list[float] | str,
        stds: list[float] | str,
        num_classes: int,
        predict_data_root: Path | None = None,
        train_label_data_root: Path | None = None,
        val_label_data_root: Path | None = None,
        test_label_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
        predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        reduce_zero_label: bool = False,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        drop_last: bool = True,
        pin_memory: bool = False,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): _description_
            num_workers (int): _description_
            train_data_root (Path): _description_
            val_data_root (Path): _description_
            test_data_root (Path): _description_
            predict_data_root (Path): _description_
            img_grep (str): _description_
            label_grep (str): _description_
            means (list[float]): _description_
            stds (list[float]): _description_
            num_classes (int): _description_
            train_label_data_root (Path | None, optional): _description_. Defaults to None.
            val_label_data_root (Path | None, optional): _description_. Defaults to None.
            test_label_data_root (Path | None, optional): _description_. Defaults to None.
            train_split (Path | None, optional): _description_. Defaults to None.
            val_split (Path | None, optional): _description_. Defaults to None.
            test_split (Path | None, optional): _description_. Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
                Naming must match that of dataset_bands. Defaults to None.
            predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
                with this value at predict time.
                Defaults to None, which does not overwrite.
            predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
                with this value at predict time. Defaults to None, which does not overwrite.
            constant_scale (float, optional): _description_. Defaults to 1.
            rgb_indices (list[int] | None, optional): _description_. Defaults to None.
            train_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the test dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float | None): Replace NaN values in input images with this value. If None, no replacement is done. Defaults to None.
            no_label_replace (int | None): Replace NaN values in label with this value. If None, no replacement is done. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to False.
            drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
            pin_memory (bool): If ``True``, the data loader will copy Tensors
                into device/CUDA pinned memory before returning them. Defaults to False.
        """
        super().__init__(GenericNonGeoSegmentationDataset, batch_size, num_workers, **kwargs)
        self.num_classes = num_classes
        self.img_grep = img_grep
        self.label_grep = label_grep
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.drop_last = drop_last
        self.pin_memory = pin_memory

        self.train_label_data_root = train_label_data_root
        self.val_label_data_root = val_label_data_root
        self.test_label_data_root = test_label_data_root

        self.dataset_bands = dataset_bands
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

        # self.aug = AugmentationSequential(
        #     K.Normalize(means, stds),
        #     data_keys=["image"],
        # )
        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)

        # self.aug = Normalize(means, stds)
        # self.collate_fn = collate_fn_list_dicts

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                self.train_root,
                self.num_classes,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.train_label_data_root,
                split=self.train_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.train_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                self.val_root,
                self.num_classes,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.val_label_data_root,
                split=self.val_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.val_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                self.test_root,
                self.num_classes,
                image_grep=self.img_grep,
                label_grep=self.label_grep,
                label_data_root=self.test_label_data_root,
                split=self.test_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )
        if stage in ["predict"] and self.predict_root:
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                self.num_classes,
                dataset_bands=self.predict_dataset_bands,
                output_bands=self.predict_output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
                reduce_zero_label=self.reduce_zero_label,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        batch_size = check_dataset_stackability(dataset, batch_size)

        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
            pin_memory=self.pin_memory,
        )
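For segmentation, `no_label_replace` and `reduce_zero_label` both rewrite label values before they reach the loss. The sketch below applies them to a small 2-D label map held as nested lists. `preprocess_label` is a hypothetical helper built from the parameter descriptions, not TerraTorch's code, and the order of the two steps (replace NaN first, then subtract 1) is an assumption.

```python
import math
from typing import List, Optional, Union

Label = List[List[Union[int, float]]]


def preprocess_label(
    label: Label,
    reduce_zero_label: bool = False,
    no_label_replace: Optional[int] = None,
) -> Label:
    """Apply no_label_replace and reduce_zero_label to a 2-D label map (list of rows).

    Hypothetical helper based on the parameter descriptions; the order of the two
    steps (replace NaN first, then subtract 1) is an assumption.
    """
    result: Label = []
    for row in label:
        new_row = []
        for value in row:
            # NaN only occurs in float labels; ints pass through this check unchanged.
            if no_label_replace is not None and isinstance(value, float) and math.isnan(value):
                value = no_label_replace
            if reduce_zero_label:
                value = value - 1  # shifts 1-based class ids to 0-based ones
            new_row.append(value)
        result.append(new_row)
    return result


print(preprocess_label([[1, 2], [3, 1]], reduce_zero_label=True))  # [[0, 1], [2, 0]]
print(preprocess_label([[2.0, float("nan")]], no_label_replace=-1))  # [[2.0, -1]]
```

Because `reduce_zero_label` subtracts 1 from every value, any 0 already present in a label map becomes -1, so the flag should only be enabled for datasets whose classes really start at 1.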
__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, img_grep, label_grep, means, stds, num_classes, predict_data_root=None, train_label_data_root=None, val_label_data_root=None, test_label_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, output_bands=None, predict_dataset_bands=None, predict_output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, reduce_zero_label=False, no_data_replace=None, no_label_replace=None, drop_last=True, pin_memory=False, **kwargs)

Constructor

Parameters:
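  • batch_size (int) –

    Number of samples per batch.

  • num_workers (int) –

    Number of worker processes used by each DataLoader.

  • train_data_root (Path) –

    Directory containing the training images.

  • val_data_root (Path) –

    Directory containing the validation images.

  • test_data_root (Path) –

    Directory containing the test images.

  • predict_data_root (Path, default: None ) –

    Directory containing the images to run prediction on. Defaults to None.

  • img_grep (str) –

    Pattern used to find input image files.

  • label_grep (str) –

    Pattern used to find label files.

  • means (list[float] | str) –

    Per-band means used to normalize the input images. A string is resolved with load_from_file_or_attribute.

  • stds (list[float] | str) –

    Per-band standard deviations used to normalize the input images. A string is resolved with load_from_file_or_attribute.

  • num_classes (int) –

    Number of segmentation classes.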

  • train_label_data_root (Path | None, default: None ) –

    Directory containing the training labels. Defaults to None.

  • val_label_data_root (Path | None, default: None ) –

    Directory containing the validation labels. Defaults to None.

  • test_label_data_root (Path | None, default: None ) –

    Directory containing the test labels. Defaults to None.

  • train_split (Path | None, default: None ) –

    Split file listing the files to include in the training set. Defaults to None.

  • val_split (Path | None, default: None ) –

    Split file listing the files to include in the validation set. Defaults to None.

  • test_split (Path | None, default: None ) –

    Split file listing the files to include in the test set. Defaults to None.

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard file extensions when using the split file to determine which files to include in the dataset. This is necessary for EuroSAT, for example, because its split files list ".jpg" names while the actual files are ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether split file entries are substrings that must appear in a file name for it to be included (as in mmsegmentation), or must match file names exactly (as in EuroSAT). Defaults to True.

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset. Defaults to None.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset. Naming must match that of dataset_bands. Defaults to None.

  • predict_dataset_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites dataset_bands with this value at predict time. Defaults to None, which does not overwrite.

  • predict_output_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites output_bands with this value at predict time. Defaults to None, which does not overwrite.

  • constant_scale (float, default: 1 ) –

    Constant factor by which image values are multiplied. Defaults to 1.

  • rgb_indices (list[int] | None, default: None ) –

    Indices of the red, green, and blue channels, used for plotting. Defaults to None.

  • train_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • val_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • test_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the test dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace NaN values in input images with this value. If None, no replacement is done. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace NaN values in labels with this value. If None, no replacement is done. Defaults to None.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • reduce_zero_label (bool, default: False ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False.

  • drop_last (bool, default: True ) –

    Drop the last batch if it is not complete. Defaults to True.

  • pin_memory (bool, default: False ) –

    If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. Defaults to False.
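The band parameters work by name. output_bands selects, and can reorder, channels listed in dataset_bands, while the predict_* variants fall back to the training values when left as None (see the `x if x else y` assignments in the constructor). The NumPy sketch below illustrates this under two stated assumptions: the image has a channel-first (bands, h, w) layout, and plain strings stand in for HLSBands members. `band_indices` and `resolve` are hypothetical helpers. The index lookup is our reading of "naming must match that of dataset_bands", not TerraTorch's code.

```python
import numpy as np


def band_indices(dataset_bands, output_bands):
    """Hypothetical helper: position of each requested band in dataset_bands."""
    # list.index raises ValueError if a requested band is not in dataset_bands.
    return [dataset_bands.index(band) for band in output_bands]


def resolve(predict_value, default_value):
    """Same fallback as the constructor: None (or an empty list) keeps the default."""
    return predict_value if predict_value else default_value


# Plain strings stand in for HLSBands members here.
dataset_bands = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]
output_bands = ["RED", "GREEN", "BLUE"]

image = np.arange(6 * 2 * 2).reshape(6, 2, 2)  # assumed (bands, h, w) layout
idx = band_indices(dataset_bands, output_bands)
rgb = image[idx]  # selects and reorders the channels
print(idx, rgb.shape)  # [2, 1, 0] (3, 2, 2)

predict_dataset_bands = resolve(None, dataset_bands)
print(predict_dataset_bands is dataset_bands)  # True
```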

Source code in terratorch/datamodules/generic_pixel_wise_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    img_grep: str,
    label_grep: str,
    means: list[float] | str,
    stds: list[float] | str,
    num_classes: int,
    predict_data_root: Path | None = None,
    train_label_data_root: Path | None = None,
    val_label_data_root: Path | None = None,
    test_label_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
    predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    reduce_zero_label: bool = False,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    drop_last: bool = True,
    pin_memory: bool = False,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
        batch_size (int): _description_
        num_workers (int): _description_
        train_data_root (Path): _description_
        val_data_root (Path): _description_
        test_data_root (Path): _description_
        predict_data_root (Path): _description_
        img_grep (str): _description_
        label_grep (str): _description_
        means (list[float]): _description_
        stds (list[float]): _description_
        num_classes (int): _description_
        train_label_data_root (Path | None, optional): _description_. Defaults to None.
        val_label_data_root (Path | None, optional): _description_. Defaults to None.
        test_label_data_root (Path | None, optional): _description_. Defaults to None.
        train_split (Path | None, optional): _description_. Defaults to None.
        val_split (Path | None, optional): _description_. Defaults to None.
        test_split (Path | None, optional): _description_. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
        output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            Naming must match that of dataset_bands. Defaults to None.
        predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
            with this value at predict time.
            Defaults to None, which does not overwrite.
        predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
            with this value at predict time. Defaults to None, which does not overwrite.
        constant_scale (float, optional): _description_. Defaults to 1.
        rgb_indices (list[int] | None, optional): _description_. Defaults to None.
        train_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the test dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float | None): Replace NaN values in input images with this value. If None, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace NaN values in label with this value. If None, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to False.
        drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        pin_memory (bool): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them. Defaults to False.
    """
    super().__init__(GenericNonGeoSegmentationDataset, batch_size, num_workers, **kwargs)
    self.num_classes = num_classes
    self.img_grep = img_grep
    self.label_grep = label_grep
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.drop_last = drop_last
    self.pin_memory = pin_memory

    self.train_label_data_root = train_label_data_root
    self.val_label_data_root = val_label_data_root
    self.test_label_data_root = test_label_data_root

    self.dataset_bands = dataset_bands
    self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
    self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices
    self.expand_temporal_dimension = expand_temporal_dimension
    self.reduce_zero_label = reduce_zero_label

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)

    # self.aug = AugmentationSequential(
    #     K.Normalize(means, stds),
    #     data_keys=["image"],
    # )
    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)
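no_data_replace, no_label_replace and reduce_zero_label are simple per-pixel operations. The NumPy sketch below shows what the parameter descriptions state: NaNs are replaced by a fixed value, and reduce_zero_label shifts every label down by one. `replace_nan` and `shift_labels` are hypothetical names. The sketch does not attempt to reproduce where in TerraTorch's loading pipeline these steps run.

```python
import numpy as np


def replace_nan(array, value):
    """Hypothetical helper: copy of `array` with NaNs set to `value`; None means no replacement."""
    if value is None:
        return array
    # Only NaN entries are touched; finite values pass through unchanged.
    return np.where(np.isnan(array), value, array)


def shift_labels(mask):
    """Hypothetical helper for reduce_zero_label: labels 1..K become 0..K-1."""
    # A raw label of 0 would become -1, so only use this when 0 is not a valid class.
    return mask - 1


image = np.array([[0.2, np.nan], [0.5, 0.7]], dtype=np.float32)
mask = np.array([[1, 2], [3, 1]])

clean = replace_nan(image, 0.0)  # as with no_data_replace=0.0
print(clean)
print(shift_labels(mask))  # labels now start at 0
```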

terratorch.datamodules.generic_scalar_label_data_module

This module contains generic data modules for instantiation at runtime.

GenericNonGeoClassificationDataModule

Bases: NonGeoDataModule

This is a generic data module class for instantiating data modules at runtime. It composes several GenericNonGeoClassificationDataset instances, one for each split.

Source code in terratorch/datamodules/generic_scalar_label_data_module.py
class GenericNonGeoClassificationDataModule(NonGeoDataModule):
    """
    This is a generic datamodule class for instantiating data modules at runtime.
    Composes several [GenericNonGeoClassificationDatasets][terratorch.datasets.GenericNonGeoClassificationDataset]
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        train_data_root: Path,
        val_data_root: Path,
        test_data_root: Path,
        means: list[float] | str,
        stds: list[float] | str,
        num_classes: int,
        predict_data_root: Path | None = None,
        train_split: Path | None = None,
        val_split: Path | None = None,
        test_split: Path | None = None,
        ignore_split_file_extensions: bool = True,
        allow_substring_split_file: bool = True,
        dataset_bands: list[HLSBands | int] | None = None,
        predict_dataset_bands: list[HLSBands | int] | None = None,
        output_bands: list[HLSBands | int] | None = None,
        constant_scale: float = 1,
        rgb_indices: list[int] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        expand_temporal_dimension: bool = False,
        no_data_replace: float = 0,
        drop_last: bool = True,
        **kwargs: Any,
    ) -> None:
        """Constructor

        Args:
            batch_size (int): _description_
            num_workers (int): _description_
            train_data_root (Path): _description_
            val_data_root (Path): _description_
            test_data_root (Path): _description_
            means (list[float]): _description_
            stds (list[float]): _description_
            num_classes (int): _description_
            predict_data_root (Path): _description_
            train_split (Path | None, optional): _description_. Defaults to None.
            val_split (Path | None, optional): _description_. Defaults to None.
            test_split (Path | None, optional): _description_. Defaults to None.
            ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
                file to determine which files to include in the dataset.
                E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
                actually ".tif". Defaults to True.
            allow_substring_split_file (bool, optional): Whether the split files contain substrings
                that must be present in file names to be included (as in mmsegmentation), or exact
                matches (e.g. eurosat). Defaults to True.
            dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None.
            predict_dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None.
            output_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None.
            constant_scale (float, optional): _description_. Defaults to 1.
            rgb_indices (list[int] | None, optional): _description_. Defaults to None.
            train_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the train dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            val_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the validation dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            test_transform (Albumentations.Compose | None): Albumentations transform
                to be applied to the test dataset.
                Should end with ToTensorV2(). If used through the generic_data_module,
                should not include normalization. Not supported for multi-temporal data.
                Defaults to None, which simply applies ToTensorV2().
            no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to False.
            drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
        """
        super().__init__(GenericNonGeoClassificationDataset, batch_size, num_workers, **kwargs)
        self.num_classes = num_classes
        self.train_root = train_data_root
        self.val_root = val_data_root
        self.test_root = test_data_root
        self.predict_root = predict_data_root
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.ignore_split_file_extensions = ignore_split_file_extensions
        self.allow_substring_split_file = allow_substring_split_file
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.drop_last = drop_last

        self.dataset_bands = dataset_bands
        self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
        self.output_bands = output_bands
        self.rgb_indices = rgb_indices
        self.expand_temporal_dimension = expand_temporal_dimension

        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)

        # self.aug = AugmentationSequential(
        #     K.Normalize(means, stds),
        #     data_keys=["image"],
        # )

        means = load_from_file_or_attribute(means)
        stds = load_from_file_or_attribute(stds)

        self.aug = Normalize(means, stds)

        # self.aug = Normalize(means, stds)
        # self.collate_fn = collate_fn_list_dicts

    def setup(self, stage: str) -> None:
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                self.train_root,
                self.num_classes,
                split=self.train_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.train_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                self.val_root,
                self.num_classes,
                split=self.val_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.val_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                self.test_root,
                self.num_classes,
                split=self.test_split,
                ignore_split_file_extensions=self.ignore_split_file_extensions,
                allow_substring_split_file=self.allow_substring_split_file,
                dataset_bands=self.dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )
        if stage in ["predict"] and self.predict_root:
            self.predict_dataset = self.dataset_class(
                self.predict_root,
                self.num_classes,
                dataset_bands=self.predict_dataset_bands,
                output_bands=self.output_bands,
                constant_scale=self.constant_scale,
                rgb_indices=self.rgb_indices,
                transform=self.test_transform,
                no_data_replace=self.no_data_replace,
                expand_temporal_dimension=self.expand_temporal_dimension,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")

        batch_size = check_dataset_stackability(dataset, batch_size)

        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
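setup builds only the datasets that a given Lightning stage needs. The sketch below restates the conditions in the setup method above as a hypothetical pure-Python function. `datasets_for_stage` is not TerraTorch API. It can help when checking, for example, why a predict run fails when predict_data_root is unset.

```python
def datasets_for_stage(stage: str, has_predict_root: bool) -> list:
    """Hypothetical summary of GenericNonGeoClassificationDataModule.setup."""
    built = []
    if stage == "fit":
        built.append("train")
    if stage in ("fit", "validate"):
        # Validation data is needed both while fitting and for a standalone validate run.
        built.append("val")
    if stage == "test":
        built.append("test")
    if stage == "predict" and has_predict_root:
        # No predict dataset is created unless predict_data_root is set.
        built.append("predict")
    return built


for stage in ("fit", "validate", "test", "predict"):
    print(stage, datasets_for_stage(stage, has_predict_root=False))
```

Unlike the other splits, the predict dataset is built without a split file, uses predict_dataset_bands, and reuses test_transform.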
__init__(batch_size, num_workers, train_data_root, val_data_root, test_data_root, means, stds, num_classes, predict_data_root=None, train_split=None, val_split=None, test_split=None, ignore_split_file_extensions=True, allow_substring_split_file=True, dataset_bands=None, predict_dataset_bands=None, output_bands=None, constant_scale=1, rgb_indices=None, train_transform=None, val_transform=None, test_transform=None, expand_temporal_dimension=False, no_data_replace=0, drop_last=True, **kwargs)

Constructor

Parameters:
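  • batch_size (int) –

    Number of samples per batch.

  • num_workers (int) –

    Number of worker processes used by each DataLoader.

  • train_data_root (Path) –

    Directory containing the training data.

  • val_data_root (Path) –

    Directory containing the validation data.

  • test_data_root (Path) –

    Directory containing the test data.

  • means (list[float] | str) –

    Per-band means used to normalize the input images. A string is resolved with load_from_file_or_attribute.

  • stds (list[float] | str) –

    Per-band standard deviations used to normalize the input images. A string is resolved with load_from_file_or_attribute.

  • num_classes (int) –

    Number of classes.

  • predict_data_root (Path, default: None ) –

    Directory containing the data to run prediction on. If None, no predict dataset is created. Defaults to None.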

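  • train_split (Path | None, default: None ) –

    Split file listing the files to include in the training set. Defaults to None.

  • val_split (Path | None, default: None ) –

    Split file listing the files to include in the validation set. Defaults to None.

  • test_split (Path | None, default: None ) –

    Split file listing the files to include in the test set. Defaults to None.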

  • ignore_split_file_extensions (bool, default: True ) –

    Whether to disregard file extensions when using the split file to determine which files to include in the dataset. This is necessary for EuroSAT, for example, because its split files list ".jpg" names while the actual files are ".tif". Defaults to True.

  • allow_substring_split_file (bool, default: True ) –

    Whether split file entries are substrings that must appear in a file name for it to be included (as in mmsegmentation), or must match file names exactly (as in EuroSAT). Defaults to True.

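  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Bands present in the dataset. Defaults to None.

  • predict_dataset_bands (list[HLSBands | int] | None, default: None ) –

    Overwrites dataset_bands with this value at predict time. Defaults to None, which does not overwrite.

  • output_bands (list[HLSBands | int] | None, default: None ) –

    Bands that should be output by the dataset. Naming must match that of dataset_bands. Defaults to None.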

  • constant_scale (float, default: 1 ) –

    Constant factor by which image values are multiplied. Defaults to 1.

  • rgb_indices (list[int] | None, default: None ) –

    Indices of the red, green, and blue channels, used for plotting. Defaults to None.

  • train_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the train dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • val_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the validation dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • test_transform (Compose | None, default: None ) –

    Albumentations transform to be applied to the test dataset. Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. Not supported for multi-temporal data. Defaults to None, which simply applies ToTensorV2().

  • no_data_replace (float, default: 0 ) –

    Replace nan values in input images with this value. Defaults to 0.

  • expand_temporal_dimension (bool, default: False ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to False.

  • drop_last (bool, default: True ) –

    Drop the last batch if it is not complete. Defaults to True.
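The two split-file flags interact. The hypothetical `filter_by_split` helper below is our reading of the parameter descriptions, not TerraTorch's implementation. Split entries are compared to file names either as substrings or as exact matches, optionally after stripping extensions. With EuroSAT-style names, it shows why substring matching is risky there: "AnnualCrop_1" is also a substring of "AnnualCrop_10".

```python
from pathlib import Path


def filter_by_split(file_names, split_entries, ignore_extensions=True, allow_substring=True):
    """Hypothetical helper: keep the file names selected by a split file."""

    def key(name: str) -> str:
        # Path.stem drops the final extension (and any directory part).
        return Path(name).stem if ignore_extensions else name

    entries = [key(line.strip()) for line in split_entries if line.strip()]
    kept = []
    for name in file_names:
        candidate = key(name)
        if allow_substring:
            matched = any(entry in candidate for entry in entries)
        else:
            matched = candidate in entries
        if matched:
            kept.append(name)
    return kept


files = ["AnnualCrop_1.tif", "AnnualCrop_10.tif", "Forest_3.tif"]
split = ["AnnualCrop_1.jpg", "Forest_3.jpg"]  # split file lists .jpg names

print(filter_by_split(files, split, ignore_extensions=True, allow_substring=False))
# exact match after dropping extensions: ['AnnualCrop_1.tif', 'Forest_3.tif']
print(filter_by_split(files, split, ignore_extensions=False, allow_substring=False))
# extensions compared too: []
print(filter_by_split(files, split, ignore_extensions=True, allow_substring=True))
# substring match also keeps AnnualCrop_10.tif
```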

Source code in terratorch/datamodules/generic_scalar_label_data_module.py
def __init__(
    self,
    batch_size: int,
    num_workers: int,
    train_data_root: Path,
    val_data_root: Path,
    test_data_root: Path,
    means: list[float] | str,
    stds: list[float] | str,
    num_classes: int,
    predict_data_root: Path | None = None,
    train_split: Path | None = None,
    val_split: Path | None = None,
    test_split: Path | None = None,
    ignore_split_file_extensions: bool = True,
    allow_substring_split_file: bool = True,
    dataset_bands: list[HLSBands | int] | None = None,
    predict_dataset_bands: list[HLSBands | int] | None = None,
    output_bands: list[HLSBands | int] | None = None,
    constant_scale: float = 1,
    rgb_indices: list[int] | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    expand_temporal_dimension: bool = False,
    no_data_replace: float = 0,
    drop_last: bool = True,
    **kwargs: Any,
) -> None:
    """Constructor

    Args:
        batch_size (int): _description_
        num_workers (int): _description_
        train_data_root (Path): _description_
        val_data_root (Path): _description_
        test_data_root (Path): _description_
        means (list[float]): _description_
        stds (list[float]): _description_
        num_classes (int): _description_
        predict_data_root (Path): _description_
        train_split (Path | None, optional): _description_. Defaults to None.
        val_split (Path | None, optional): _description_. Defaults to None.
        test_split (Path | None, optional): _description_. Defaults to None.
        ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split
            file to determine which files to include in the dataset.
            E.g. necessary for EuroSAT, since the split files specify ".jpg" but files are
            actually ".tif". Defaults to True.
        allow_substring_split_file (bool, optional): Whether the split files contain substrings
            that must be present in file names to be included (as in mmsegmentation), or exact
            matches (e.g. eurosat). Defaults to True.
        dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None.
        predict_dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None.
        output_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None.
        constant_scale (float, optional): _description_. Defaults to 1.
        rgb_indices (list[int] | None, optional): _description_. Defaults to None.
        train_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the train dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        val_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the validation dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        test_transform (Albumentations.Compose | None): Albumentations transform
            to be applied to the test dataset.
            Should end with ToTensorV2(). If used through the generic_data_module,
            should not include normalization. Not supported for multi-temporal data.
            Defaults to None, which simply applies ToTensorV2().
        no_data_replace (float): Replace nan values in input images with this value. Defaults to 0.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to False.
        drop_last (bool): Drop the last batch if it is not complete. Defaults to True.
    """
    super().__init__(GenericNonGeoClassificationDataset, batch_size, num_workers, **kwargs)
    self.num_classes = num_classes
    self.train_root = train_data_root
    self.val_root = val_data_root
    self.test_root = test_data_root
    self.predict_root = predict_data_root
    self.train_split = train_split
    self.val_split = val_split
    self.test_split = test_split
    self.ignore_split_file_extensions = ignore_split_file_extensions
    self.allow_substring_split_file = allow_substring_split_file
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.drop_last = drop_last

    self.dataset_bands = dataset_bands
    self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
    self.output_bands = output_bands
    self.rgb_indices = rgb_indices
    self.expand_temporal_dimension = expand_temporal_dimension

    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)

    # self.aug = AugmentationSequential(
    #     K.Normalize(means, stds),
    #     data_keys=["image"],
    # )

    means = load_from_file_or_attribute(means)
    stds = load_from_file_or_attribute(stds)

    self.aug = Normalize(means, stds)
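expand_temporal_dimension turns a stacked (time*channels, h, w) image into (channels, time, h, w). After batching, normalization therefore has to broadcast per-band statistics over a 5-D (batch, channels, time, h, w) array. The NumPy sketch below shows both steps. `expand_temporal` and `normalize` are hypothetical helpers. The stacking order of your files (channel-major or time-major) is an assumption you must check against your data, so the sketch supports both orders. The z-score formula (x - mean) / std is the standard per-channel normalization and is assumed here rather than copied from TerraTorch's Normalize.

```python
import numpy as np


def expand_temporal(image: np.ndarray, num_channels: int, channel_major: bool = True) -> np.ndarray:
    """Hypothetical helper: reshape (time*channels, h, w) into (channels, time, h, w)."""
    stacked, h, w = image.shape
    if stacked % num_channels:
        raise ValueError("leading axis is not a multiple of num_channels")
    steps = stacked // num_channels
    if channel_major:
        # Stacked as c0t0, c0t1, ..., c1t0, ...: a plain reshape suffices.
        return image.reshape(num_channels, steps, h, w)
    # Stacked as t0c0, t0c1, ..., t1c0, ...: split time first, then swap axes.
    return image.reshape(steps, num_channels, h, w).transpose(1, 0, 2, 3)


def normalize(batch: np.ndarray, means, stds) -> np.ndarray:
    """Hypothetical per-band z-score for (B, C, H, W) or (B, C, T, H, W) batches."""
    means = np.asarray(means, dtype=np.float32)
    stds = np.asarray(stds, dtype=np.float32)
    # Put the band statistics on axis 1 and let NumPy broadcast over the rest.
    shape = (1, -1) + (1,) * (batch.ndim - 2)
    return (batch - means.reshape(shape)) / stds.reshape(shape)


image = np.arange(6, dtype=np.float32).reshape(6, 1, 1)  # 2 bands x 3 timesteps
cube = expand_temporal(image, num_channels=2)
print(cube.shape)  # (2, 3, 1, 1)

batch = cube[np.newaxis]  # add a batch axis: (1, 2, 3, 1, 1)
print(normalize(batch, means=[0.0, 3.0], stds=[1.0, 2.0])[0, :, :, 0, 0])
```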

Custom datasets and data modules

Our custom datasets and data modules are built to handle specific datasets, giving you finer control over loading, transforms, and plotting throughout the workflow. If you want to use Terratorch with your own data, we encourage you to write your own dataset and data module classes, following the examples below.

Datasets

terratorch.datasets.biomassters

BioMasstersNonGeo

Bases: BioMassters

BioMassters Dataset for Aboveground Biomass prediction.

A dataset for Aboveground Biomass (AGB) prediction over Finnish forests from Sentinel 1 and 2 data, with target AGB masks generated from Light Detection and Ranging (LiDAR) data.

Dataset Format:

  • .tif files for Sentinel 1 and 2 data
  • .tif file for pixel wise AGB target mask
  • .csv files for metadata regarding features and targets

Dataset Features:

  • 13,000 target AGB masks of 256x256 pixels
  • 12 months of data per target mask
  • Sentinel 1 and Sentinel 2 data for each location
  • Sentinel 1 available for every month
  • Sentinel 2 available for almost every month (some months are missing because ESA halted acquisitions over the region during certain periods)

If you use this dataset in your research, please cite the following paper:

  • https://nascetti-a.github.io/BioMasster/

Added in version 0.5.

Source code in terratorch/datasets/biomassters.py
class BioMasstersNonGeo(BioMassters):
    """[BioMassters Dataset](https://huggingface.co/datasets/ibm-nasa-geospatial/BioMassters) for Aboveground Biomass prediction.

    Dataset intended for Aboveground Biomass (AGB) prediction
    over Finnish forests based on Sentinel 1 and 2 data with
    corresponding target AGB mask values generated by Light Detection
    and Ranging (LiDAR).

    Dataset Format:

    * .tif files for Sentinel 1 and 2 data
    * .tif file for pixel wise AGB target mask
    * .csv files for metadata regarding features and targets

    Dataset Features:

    * 13,000 target AGB masks of size (256x256px)
    * 12 months of data per target mask
    * Sentinel 1 and Sentinel 2 data for each location
    * Sentinel 1 available for every month
    * Sentinel 2 available for almost every month
      (not available for every month due to ESA acquisition halt over the region
      during particular periods)

    If you use this dataset in your research, please cite the following paper:

    * https://nascetti-a.github.io/BioMasster/

    .. versionadded:: 0.5
    """

    S1_BAND_NAMES = ["VV_Asc", "VH_Asc", "VV_Desc", "VH_Desc", "RVI_Asc", "RVI_Desc"]
    S2_BAND_NAMES = [
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
        "CLOUD_PROBABILITY",
    ]

    all_band_names = {
        "S1": S1_BAND_NAMES,
        "S2": S2_BAND_NAMES,
    }

    rgb_bands = {
        "S1": [],
        "S2": ["RED", "GREEN", "BLUE"],
    }

    valid_splits = ("train", "test")
    valid_sensors = ("S1", "S2")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    default_metadata_filename = "The_BioMassters_-_features_metadata.csv.csv"

    def __init__(
        self,
        root = "data",
        split: str = "train",
        bands: dict[str, Sequence[str]] | Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        mask_mean: float | None = 63.4584,
        mask_std: float | None = 72.21242,
        sensors: Sequence[str] = ["S1", "S2"],
        as_time_series: bool = False,
        metadata_filename: str = default_metadata_filename,
        max_cloud_percentage: float | None = None,
        max_red_mean: float | None = None,
        include_corrupt: bool = True,
        subset: float = 1,
        seed: int = 42,
        use_four_frames: bool = False
    ) -> None:
        """Initialize a new instance of BioMassters dataset.

        If ``as_time_series=False`` (the default), each time step becomes its own
        sample with the target being shared across multiple samples.

        Args:
            root: root directory where dataset can be found
            split: train or test split
            sensors: which sensors to consider for the sample, Sentinel 1 and/or
                Sentinel 2 ('S1', 'S2')
            as_time_series: whether or not to return all available
                time-steps or just a single one for a given target location
            metadata_filename: metadata file to be used
            max_cloud_percentage: maximum allowed cloud percentage for images
            max_red_mean: maximum allowed red_mean value for images
            include_corrupt: whether to include images marked as corrupted

        Raises:
            AssertionError: if ``split`` or ``sensors`` is invalid
            DatasetNotFoundError: If dataset is not found.
        """
        self.root = root
        self.sensors = sensors
        self.bands = bands
        assert (
            split in self.valid_splits
        ), f"Please choose one of the valid splits: {self.valid_splits}."
        self.split = split

        assert set(sensors).issubset(
            set(self.valid_sensors)
        ), f"Please choose a subset of valid sensors: {self.valid_sensors}."

        if len(self.sensors) == 1:
            sens = self.sensors[0]
            self.band_indices = [
                self.all_band_names[sens].index(band) for band in self.bands[sens]
            ]
        else:
            self.band_indices = {
                sens: [self.all_band_names[sens].index(band) for band in self.bands[sens]]
                for sens in self.sensors
            }

        self.mask_mean = mask_mean
        self.mask_std = mask_std
        self.as_time_series = as_time_series
        self.metadata_filename = metadata_filename
        self.max_cloud_percentage = max_cloud_percentage
        self.max_red_mean = max_red_mean
        self.include_corrupt = include_corrupt
        self.subset = subset
        self.seed = seed
        self.use_four_frames = use_four_frames

        self._verify()

        # open metadata csv files
        self.df = pd.read_csv(os.path.join(self.root, self.metadata_filename))

        # Filter sensors
        self.df = self.df[self.df["satellite"].isin(self.sensors)]

        # Filter split
        self.df = self.df[self.df["split"] == self.split]

        # Optional filtering
        self._filter_and_select_data()

        # Optional subsampling
        self._random_subsample()

        # generate numerical month from filename since first month is September
        # and has numerical index of 0
        self.df["num_month"] = (
            self.df["filename"]
            .str.split("_", expand=True)[2]
            .str.split(".", expand=True)[0]
            .astype(int)
        )

        # Set dataframe index depending on the task for easier indexing
        if self.as_time_series:
            self.df["num_index"] = self.df.groupby(["chip_id"]).ngroup()
        else:
            filter_df = (
                self.df.groupby(["chip_id", "month"])["satellite"].count().reset_index()
            )
            filter_df = filter_df[
                filter_df["satellite"] == len(self.sensors)
            ].drop("satellite", axis=1)
            # Guarantee that each sample has corresponding number of images available
            self.df = self.df.merge(filter_df, on=["chip_id", "month"], how="inner")

            self.df["num_index"] = self.df.groupby(["chip_id", "month"]).ngroup()

        # Adjust transforms based on the number of sensors
        if len(self.sensors) == 1:
            self.transform = transform if transform else default_transform
        elif transform is None:
            self.transform = MultimodalToTensor(self.sensors)
        else:
            transform = {
                s: transform[s] if s in transform else default_transform
                for s in self.sensors
            }
            self.transform = MultimodalTransforms(transform, shared=False)

        if self.use_four_frames:
            self._select_4_frames()

    def __len__(self) -> int:
        return len(self.df["num_index"].unique())

    def _load_input(self, filenames: list[Path]) -> Tensor:
        """Load the input imagery at the index.

        Args:
            filenames: list of filenames corresponding to input

        Returns:
            input image
        """
        filepaths = [
            os.path.join(self.root, f"{self.split}_features", f) for f in filenames
        ]
        arr_list = [rasterio.open(fp).read() for fp in filepaths]

        if self.as_time_series:
            arr = np.stack(arr_list, axis=0) # (T, C, H, W)
        else:
            arr = np.concatenate(arr_list, axis=0)
        return arr.astype(np.int32)

    def _load_target(self, filename: Path) -> Tensor:
        """Load the target mask at the index.

        Args:
            filename: filename of target to index

        Returns:
            target mask
        """
        with rasterio.open(os.path.join(self.root, f"{self.split}_agbm", filename), "r") as src:
            arr: np.typing.NDArray[np.float64] = src.read()

        return arr

    def _compute_rvi(self, img: np.ndarray, linear: np.ndarray, sens: str) -> np.ndarray:
        """Compute the RVI indices for S1 data."""
        rvi_channels = []
        if self.as_time_series:
            if "RVI_Asc" in self.bands[sens]:
                try:
                    vv_asc_index = self.all_band_names["S1"].index("VV_Asc")
                    vh_asc_index = self.all_band_names["S1"].index("VH_Asc")
                except ValueError as e:
                    msg = f"RVI_Asc needs band: {e}"
                    raise ValueError(msg) from e

                VV = linear[:, vv_asc_index, :, :]
                VH = linear[:, vh_asc_index, :, :]
                rvi_asc = 4 * VH / (VV + VH + 1e-6)
                rvi_asc = np.expand_dims(rvi_asc, axis=1)
                rvi_channels.append(rvi_asc)
            if "RVI_Desc" in self.bands[sens]:
                try:
                    vv_desc_index = self.all_band_names["S1"].index("VV_Desc")
                    vh_desc_index = self.all_band_names["S1"].index("VH_Desc")
                except ValueError as e:
                    msg = f"RVI_Desc needs band: {e}"
                    raise ValueError(msg) from e

                VV_desc = linear[:, vv_desc_index, :, :]
                VH_desc = linear[:, vh_desc_index, :, :]
                rvi_desc = 4 * VH_desc / (VV_desc + VH_desc + 1e-6)
                rvi_desc = np.expand_dims(rvi_desc, axis=1)
                rvi_channels.append(rvi_desc)
            if rvi_channels:
                rvi_concat = np.concatenate(rvi_channels, axis=1)
                img = np.concatenate([img, rvi_concat], axis=1)
        else:
            if "RVI_Asc" in self.bands[sens]:
                if linear.shape[0] < 2:
                    msg = f"Not enough bands to calculate RVI_Asc. Available bands: {linear.shape[0]}"
                    raise ValueError(msg)
                VV = linear[0]
                VH = linear[1]
                rvi_asc = 4 * VH / (VV + VH + 1e-6)
                rvi_asc = np.expand_dims(rvi_asc, axis=0)
                rvi_channels.append(rvi_asc)
            if "RVI_Desc" in self.bands[sens]:
                if linear.shape[0] < 4:
                    msg = f"Not enough bands to calculate RVI_Desc. Available bands: {linear.shape[0]}"
                    raise ValueError(msg)
                VV_desc = linear[2]
                VH_desc = linear[3]
                rvi_desc = 4 * VH_desc / (VV_desc + VH_desc + 1e-6)
                rvi_desc = np.expand_dims(rvi_desc, axis=0) 
                rvi_channels.append(rvi_desc)
            if rvi_channels:
                rvi_concat = np.concatenate(rvi_channels, axis=0)
                img = np.concatenate([linear, rvi_concat], axis=0)
        return img

    def _select_4_frames(self):
        """Filter the dataset to select only 4 frames per sample."""

        if "cloud_percentage" in self.df.columns:
            self.df = self.df.sort_values(by=["chip_id", "cloud_percentage"])
        else:
            self.df = self.df.sort_values(by=["chip_id", "num_month"])

        self.df = (
            self.df.groupby("chip_id")
            .head(4)  # Select the first 4 frames per chip
            .reset_index(drop=True)
        )

    def _process_sensor_images(self, sens: str, sens_filepaths: list[str]) -> np.ndarray:
        """Process images for a given sensor."""
        img = self._load_input(sens_filepaths)
        if sens == "S1":
            img = img.astype(np.float32)
            linear = 10 ** (img / 10)
            img = self._compute_rvi(img, linear, sens)
        if self.as_time_series:
            img = img.transpose(0, 2, 3, 1)  # (T, H, W, C)
        else:
            img = img.transpose(1, 2, 0)  # (H, W, C)
        if len(self.sensors) == 1:
            img = img[..., self.band_indices]
        else:
            img = img[..., self.band_indices[sens]]
        return img

    def __getitem__(self, index: int) -> dict:
        sample_df = self.df[self.df["num_index"] == index].copy()
        # Sort by satellite and month
        sample_df.sort_values(
            by=["satellite", "num_month"], inplace=True, ascending=True
        )

        filepaths = sample_df["filename"].tolist()
        output = {}

        if len(self.sensors) == 1:
            sens = self.sensors[0]
            sens_filepaths = [fp for fp in filepaths if sens in fp]
            img = self._process_sensor_images(sens, sens_filepaths)
            output["image"] = img.astype(np.float32)
        else:
            for sens in self.sensors:
                sens_filepaths = [fp for fp in filepaths if sens in fp]
                img = self._process_sensor_images(sens, sens_filepaths)
                output[sens] = img.astype(np.float32)

        # Load target
        target_filename = sample_df["corresponding_agbm"].unique()[0]
        target = np.array(self._load_target(Path(target_filename)))
        target = target.transpose(1, 2, 0)
        output["mask"] = target
        if self.transform:
            if len(self.sensors) == 1:
                output = self.transform(**output)
            else:
                output = self.transform(output)
        output["mask"] = output["mask"].squeeze().float()
        return output

    def _filter_and_select_data(self):
        if (
            self.max_cloud_percentage is not None
            and "cloud_percentage" in self.df.columns
        ):
            self.df = self.df[self.df["cloud_percentage"] <= self.max_cloud_percentage]

        if self.max_red_mean is not None and "red_mean" in self.df.columns:
            self.df = self.df[self.df["red_mean"] <= self.max_red_mean]

        if not self.include_corrupt and "corrupt_values" in self.df.columns:
            # Element-wise mask; `series is False` would always be a single Python False.
            self.df = self.df[~self.df["corrupt_values"].astype(bool)]

    def _random_subsample(self):
        if self.split == "train" and self.subset < 1.0:
            num_samples = int(len(self.df["num_index"].unique()) * self.subset)
            if self.seed is not None:
                random.seed(self.seed)
            selected_indices = random.sample(
                list(self.df["num_index"].unique()), num_samples
            )
            self.df = self.df[self.df["num_index"].isin(selected_indices)]
            self.df.reset_index(drop=True, inplace=True)

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional suptitle to use for figure

        Returns:
            a matplotlib Figure with the rendered sample
        """
        # Determine if the sample contains multiple sensors or a single sensor
        if isinstance(sample["image"], dict):
            ncols = len(self.sensors) + 1
        else:
            ncols = 2  # One for the image and one for the mask

        showing_predictions = "prediction" in sample
        if showing_predictions:
            ncols += 1

        fig, axs = plt.subplots(1, ncols=ncols, figsize=(5 * ncols, 10))

        if isinstance(sample["image"], dict):
            # Multiple sensors case
            for idx, sens in enumerate(self.sensors):
                img = sample["image"][sens].numpy()
                if self.as_time_series:
                    # Plot last time step
                    img = img[:, -1, ...]
                if sens == "S2":
                    img = img[[2, 1, 0], ...].transpose(1, 2, 0)
                    img = percentile_normalization(img)
                else:
                    co_polarization = img[0]  # transmit == receive
                    cross_polarization = img[1]  # transmit != receive
                    ratio = co_polarization / (cross_polarization + 1e-6)

                    co_polarization = np.clip(co_polarization / 0.3, 0, 1)
                    cross_polarization = np.clip(cross_polarization / 0.05, 0, 1)
                    ratio = np.clip(ratio / 25, 0, 1)

                    img = np.stack(
                        (co_polarization, cross_polarization, ratio), axis=0
                    )
                    img = img.transpose(1, 2, 0)  # Convert to (H, W, 3)

                axs[idx].imshow(img)
                axs[idx].axis("off")
                if show_titles:
                    axs[idx].set_title(sens)
            mask_idx = len(self.sensors)
        else:
            # Single sensor case
            sens = self.sensors[0]
            img = sample["image"].numpy()
            if self.as_time_series:
                # Plot last time step
                img = img[:, -1, ...]
            if sens == "S2":
                img = img[[2, 1, 0], ...].transpose(1, 2, 0)
                img = percentile_normalization(img)
            else:
                co_polarization = img[0]  # transmit == receive
                cross_polarization = img[1]  # transmit != receive
                ratio = co_polarization / (cross_polarization + 1e-6)

                co_polarization = np.clip(co_polarization / 0.3, 0, 1)
                cross_polarization = np.clip(cross_polarization / 0.05, 0, 1)
                ratio = np.clip(ratio / 25, 0, 1)

                img = np.stack(
                    (co_polarization, cross_polarization, ratio), axis=0
                )
                img = img.transpose(1, 2, 0)  # Convert to (H, W, 3)

            axs[0].imshow(img)
            axs[0].axis("off")
            if show_titles:
                axs[0].set_title(sens)
            mask_idx = 1

        # Plot target mask
        if "mask" in sample:
            target = sample["mask"].squeeze()
            target_im = axs[mask_idx].imshow(target, cmap="YlGn")
            plt.colorbar(target_im, ax=axs[mask_idx], fraction=0.046, pad=0.04)
            axs[mask_idx].axis("off")
            if show_titles:
                axs[mask_idx].set_title("Target")

        # Plot prediction if available
        if showing_predictions:
            pred_idx = mask_idx + 1
            prediction = sample["prediction"].squeeze()
            pred_im = axs[pred_idx].imshow(prediction, cmap="YlGn")
            plt.colorbar(pred_im, ax=axs[pred_idx], fraction=0.046, pad=0.04)
            axs[pred_idx].axis("off")
            if show_titles:
                axs[pred_idx].set_title("Prediction")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
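For Sentinel 1, the dataset converts backscatter from dB to linear power (`10 ** (img / 10)`) and can append the radar vegetation index, RVI = 4 * VH / (VV + VH), for the ascending and descending passes. A minimal sketch of that computation, assuming a single time step with the four input channels in the order VV_Asc, VH_Asc, VV_Desc, VH_Desc so that the output matches `S1_BAND_NAMES`. The function names are hypothetical. Note that in the single-time-step branch the dataset returns the linear-scale channels with RVI appended, while the time-series branch appends RVI to the original dB values; this sketch follows the single-time-step behaviour.

```python
import numpy as np


def db_to_linear(db):
    """Convert backscatter from decibels to linear power: 10 ** (dB / 10)."""
    return 10.0 ** (np.asarray(db, dtype=np.float32) / 10.0)


def radar_vegetation_index(vv_linear, vh_linear, eps=1e-6):
    """RVI = 4 * VH / (VV + VH), with eps guarding against division by zero."""
    return 4.0 * vh_linear / (vv_linear + vh_linear + eps)


def append_rvi(s1_db):
    """Append RVI_Asc and RVI_Desc to a (4, H, W) dB stack ordered
    VV_Asc, VH_Asc, VV_Desc, VH_Desc, giving six channels in S1_BAND_NAMES order.
    The first four returned channels are in linear scale.
    """
    linear = db_to_linear(s1_db)
    rvi_asc = radar_vegetation_index(linear[0], linear[1])
    rvi_desc = radar_vegetation_index(linear[2], linear[3])
    return np.concatenate([linear, rvi_asc[None], rvi_desc[None]], axis=0)


s1_db = np.zeros((4, 2, 2), dtype=np.float32)  # 0 dB -> linear power 1.0 in every band
out = append_rvi(s1_db)
print(out.shape)  # (6, 2, 2)
```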
__init__(root='data', split='train', bands=BAND_SETS['all'], transform=None, mask_mean=63.4584, mask_std=72.21242, sensors=['S1', 'S2'], as_time_series=False, metadata_filename=default_metadata_filename, max_cloud_percentage=None, max_red_mean=None, include_corrupt=True, subset=1, seed=42, use_four_frames=False)

Initialize a new instance of BioMassters dataset.

If as_time_series=False (the default), each time step becomes its own sample with the target being shared across multiple samples.
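The indexing behind this switch can be sketched in plain Python. With `as_time_series=True`, all rows of a chip form one sample; otherwise each (chip, month) pair is a sample, and pairs that lack an image for one of the requested sensors are dropped, as the constructor's merge step does. The helper name and the record dicts are hypothetical stand-ins for the metadata dataframe.

```python
from collections import defaultdict


def build_sample_groups(records, sensors, as_time_series=False):
    """Group metadata rows into samples (hypothetical helper, not the dataset API).

    Each record is a dict with "chip_id", "month" and "satellite" keys.
    """
    groups = defaultdict(list)
    for rec in records:
        if rec["satellite"] not in sensors:
            continue
        key = rec["chip_id"] if as_time_series else (rec["chip_id"], rec["month"])
        groups[key].append(rec)
    if not as_time_series:
        # Drop (chip, month) pairs that lack an image for one of the sensors.
        groups = {k: v for k, v in groups.items() if len(v) == len(sensors)}
    # Sorted keys give consecutive sample indices, like pandas groupby(...).ngroup().
    return [groups[k] for k in sorted(groups)]


rows = [
    ("a", "September", "S1"), ("a", "September", "S2"), ("a", "October", "S1"),
    ("b", "September", "S1"), ("b", "September", "S2"),
    ("b", "October", "S1"), ("b", "October", "S2"),
]
records = [dict(zip(("chip_id", "month", "satellite"), r)) for r in rows]
print(len(build_sample_groups(records, ["S1", "S2"])))  # 3
print(len(build_sample_groups(records, ["S1", "S2"], as_time_series=True)))  # 2
```

Chip "a" has no S2 image for October, so that month is not a sample in the per-time-step mode, while in time-series mode chip "a" still becomes one sample with three images.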

Parameters:
  • root

    root directory where dataset can be found

  • split (str, default: 'train' ) –

    train or test split

  • sensors (Sequence[str], default: ['S1', 'S2'] ) –

    which sensors to consider for the sample, Sentinel 1 and/or Sentinel 2 ('S1', 'S2')

  • as_time_series (bool, default: False ) –

    whether or not to return all available time-steps or just a single one for a given target location

  • metadata_filename (str, default: default_metadata_filename ) –

    metadata file to be used

  • max_cloud_percentage (float | None, default: None ) –

    maximum allowed cloud percentage for images

  • max_red_mean (float | None, default: None ) –

    maximum allowed red_mean value for images

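  • include_corrupt (bool, default: True ) –

    whether to include images marked as corrupted

  • bands (dict[str, Sequence[str]] | Sequence[str], default: BAND_SETS['all'] ) –

    bands to load for each sensor, given as a mapping from sensor name ('S1', 'S2') to band names

  • transform (A.Compose | None, default: None ) –

    Albumentations transform to apply; with multiple sensors, a mapping from sensor name to transform. If None, a default transform is used

  • subset (float, default: 1 ) –

    fraction of samples to keep; only applied to the train split

  • seed (int, default: 42 ) –

    random seed used when subsampling

  • use_four_frames (bool, default: False ) –

    keep at most four frames per chip, preferring the lowest cloud percentage when that column is available and the earliest months otherwise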

Raises:
  • AssertionError

    if split or sensors is invalid

  • DatasetNotFoundError

    If dataset is not found.
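The optional filters (`max_cloud_percentage`, `max_red_mean`, `include_corrupt`) and the `subset` fraction act on the metadata rows before samples are indexed. A minimal sketch of both steps on plain dict records; the helper names and the record fields used in the example are hypothetical. One deliberate difference: the dataset's `_random_subsample` reseeds Python's global `random` module (and only runs for the train split), whereas this sketch uses a local `random.Random` instance.

```python
import random


def filter_records(records, max_cloud_percentage=None, max_red_mean=None, include_corrupt=True):
    """Drop metadata rows that fail the optional quality thresholds (hypothetical helper).

    A threshold is only applied when it is set and the row carries that field,
    mirroring the column checks in the dataset code.
    """
    kept = []
    for rec in records:
        if max_cloud_percentage is not None and "cloud_percentage" in rec:
            if rec["cloud_percentage"] > max_cloud_percentage:
                continue
        if max_red_mean is not None and "red_mean" in rec:
            if rec["red_mean"] > max_red_mean:
                continue
        if not include_corrupt and rec.get("corrupt_values", False):
            continue
        kept.append(rec)
    return kept


def subsample_indices(indices, fraction, seed=42):
    """Keep int(len * fraction) sample indices chosen at random with a fixed seed."""
    indices = list(indices)
    if fraction >= 1.0:
        return indices
    rng = random.Random(seed)  # local generator; does not touch the global random state
    return rng.sample(indices, int(len(indices) * fraction))


recs = [
    {"id": 0, "cloud_percentage": 5, "red_mean": 100, "corrupt_values": False},
    {"id": 1, "cloud_percentage": 80, "red_mean": 100, "corrupt_values": False},
    {"id": 2, "cloud_percentage": 5, "red_mean": 5000, "corrupt_values": False},
    {"id": 3, "cloud_percentage": 5, "red_mean": 100, "corrupt_values": True},
]
print([r["id"] for r in filter_records(recs, max_cloud_percentage=20, include_corrupt=False)])  # [0, 2]
print(len(subsample_indices(range(10), 0.5)))  # 5
```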

plot(sample, show_titles=True, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    a sample returned by __getitem__

  • show_titles (bool, default: True ) –

    flag indicating whether to show titles above each panel

  • suptitle (str | None, default: None ) –

    optional suptitle to use for figure

Returns:
  • Figure

    a matplotlib Figure with the rendered sample
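For Sentinel 1 panels, `plot` builds a false-colour composite from the first two channels, treated as co-polarization (VV) and cross-polarization (VH), using fixed scaling constants. The sketch isolates that mapping; the function name is hypothetical, while the constants (0.3, 0.05, 25) are the ones used in the method.

```python
import numpy as np


def s1_false_color(vv, vh, eps=1e-6):
    """Build an (H, W, 3) RGB preview from linear VV and VH backscatter.

    Red = VV / 0.3, green = VH / 0.05, blue = (VV / VH) / 25, each clipped to [0, 1].
    """
    ratio = vv / (vh + eps)
    red = np.clip(vv / 0.3, 0, 1)
    green = np.clip(vh / 0.05, 0, 1)
    blue = np.clip(ratio / 25, 0, 1)
    # Stacking on the last axis yields channel-last (H, W, 3), as imshow expects.
    return np.stack((red, green, blue), axis=-1)


rgb = s1_false_color(np.full((2, 2), 0.15), np.full((2, 2), 0.1))
print(rgb.shape)  # (2, 2, 3)
```

Stacking on `axis=-1` is equivalent to the method's `np.stack(..., axis=0)` followed by `transpose(1, 2, 0)`.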


terratorch.datasets.burn_intensity

BurnIntensityNonGeo

Bases: NonGeoDataset

Dataset implementation for Burn Intensity classification.
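`BurnIntensityNonGeo` keeps only the requested `bands` by checking them against `all_band_names` and storing their positions as `band_indices`. A minimal sketch of that lookup for a channel-first tile; `resolve_band_indices` is a hypothetical name, and the real class delegates the check to TerraTorch's `validate_bands`, whose exact error behaviour is not reproduced here.

```python
import numpy as np

ALL_BAND_NAMES = ("BLUE", "GREEN", "RED", "NIR", "SWIR_1", "SWIR_2")


def resolve_band_indices(bands, all_band_names=ALL_BAND_NAMES):
    """Map requested band names to channel positions (hypothetical stand-in for validate_bands + index lookup)."""
    unknown = [b for b in bands if b not in all_band_names]
    if unknown:
        raise ValueError(f"Unknown bands {unknown}; expected a subset of {all_band_names}")
    return np.asarray([all_band_names.index(b) for b in bands])


rgb_idx = resolve_band_indices(("RED", "GREEN", "BLUE"))
print(rgb_idx)  # [2 1 0]
image = np.arange(6 * 2 * 2).reshape(6, 2, 2)  # toy 6-band (C, H, W) tile
print(image[rgb_idx].shape)  # (3, 2, 2)
```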

Source code in terratorch/datasets/burn_intensity.py
class BurnIntensityNonGeo(NonGeoDataset):
    """Dataset implementation for [Burn Intensity classification](https://huggingface.co/datasets/ibm-nasa-geospatial/burn_intensity)."""

    all_band_names = (
        "BLUE", "GREEN", "RED", "NIR", "SWIR_1", "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    class_names = (
        "No burn",
        "Unburned to Very Low",
        "Low Severity",
        "Moderate Severity",
        "High Severity"
    )

    CSV_FILES = {
        "limited": "BS_files_with_less_than_25_percent_zeros.csv",
        "full": "BS_files_raw.csv",
    }

    num_classes = 5
    splits = {"train": "train", "val": "val"}
    time_steps = ["pre", "during", "post"]

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        use_full_data: bool = True,
        no_data_replace: float | None = 0.0001,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
    ) -> None:
        """Initialize the BurnIntensity dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train' or 'val'.
            bands (Sequence[str]): Bands to output. Defaults to all bands.
            transform (Optional[A.Compose]): Albumentations transform to be applied.
            use_full_data (bool): Whether to use the full data or only cases with less than 25 percent zeros.
            no_data_replace (Optional[float]): Value to replace NaNs in images.
            no_label_replace (Optional[int]): Value to replace NaNs in labels.
            use_metadata (bool): Whether to return metadata (image center location).
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])

        self.data_root = Path(data_root)

        # Read the CSV file to get the list of cases to include
        csv_file_key = "full" if use_full_data else "limited"
        csv_path = self.data_root / self.CSV_FILES[csv_file_key]
        df = pd.read_csv(csv_path)
        casenames = df["Case_Name"].tolist()

        split_file = self.data_root / f"{split}.txt"
        with open(split_file) as f:
            split_images = [line.strip() for line in f.readlines()]

        split_images = [img for img in split_images if self._extract_casename(img) in casenames]

        # Build the samples list
        self.samples = []
        for image_filename in split_images:
            image_files = []
            for time_step in self.time_steps:
                image_file = self.data_root / time_step / image_filename
                image_files.append(str(image_file))
            mask_filename = image_filename.replace("HLS_", "BS_")
            mask_file = self.data_root / "pre" / mask_filename
            self.samples.append({
                "image_files": image_files,
                "mask_file": str(mask_file),
                "casename": self._extract_casename(image_filename),
            })

        self.use_metadata = use_metadata
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace

        self.transform = transform if transform else default_transform

    def _extract_basename(self, filepath: str) -> str:
        """Extract the base filename without extension."""
        return os.path.splitext(os.path.basename(filepath))[0]

    def _extract_casename(self, filename: str) -> str:
        """Extract the casename from the filename."""
        basename = self._extract_basename(filename)
        # Remove 'HLS_' or 'BS_' prefix
        casename = basename.replace("HLS_", "").replace("BS_", "")
        return casename

    def __len__(self) -> int:
        return len(self.samples)

    def _get_coords(self, image: DataArray) -> torch.Tensor:
        """Return the image center as (y, x) in the raster's native CRS."""
        # Midpoints of the bounds give the center regardless of the sign of the
        # y resolution (negative for north-up rasters).
        left, bottom, right, top = image.rio.bounds()
        center_lon = (left + right) / 2
        center_lat = (bottom + top) / 2

        lat_lon = np.asarray([center_lat, center_lon])
        return torch.tensor(lat_lon, dtype=torch.float32)

    def __getitem__(self, index: int) -> dict[str, Any]:
        sample = self.samples[index]
        image_files = sample["image_files"]
        mask_file = sample["mask_file"]

        images = []
        for idx, image_file in enumerate(image_files):
            image = self._load_file(Path(image_file), nan_replace=self.no_data_replace)
            if idx == 0 and self.use_metadata:
                location_coords = self._get_coords(image)
            image = image.to_numpy()
            image = np.moveaxis(image, 0, -1)
            image = image[..., self.band_indices]
            images.append(image)

        images = np.stack(images, axis=0)  # (T, H, W, C)

        output = {
            "image": images.astype(np.float32),
            "mask": self._load_file(Path(mask_file), nan_replace=self.no_label_replace).to_numpy()[0]
        }

        if self.transform:
            output = self.transform(**output)

        output["mask"] = output["mask"].long()
        if self.use_metadata:
            output["location_coords"] = location_coords

        return output

    def _load_file(self, path: Path, nan_replace: float | int | None = None) -> DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data


    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Any:
        """Plot a sample from the dataset.

        Args:
            sample: A sample returned by `__getitem__`.
            suptitle: Optional string to use as a suptitle.

        Returns:
            A matplotlib Figure with the rendered sample.
        """
        num_images = len(self.time_steps) + 2
        if "prediction" in sample:
            num_images += 1

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        images = sample["image"]  # (C, T, H, W)
        mask = sample["mask"].numpy()
        num_classes = self.num_classes  # full class count, not just the values present in this mask

        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 5, 5))

        for i in range(len(self.time_steps)):
            image = images[:, i, :, :]  # (C, H, W)
            image = np.transpose(image, (1, 2, 0))  # (H, W, C)
            rgb_image = image[..., rgb_indices]
            rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
            rgb_image = np.clip(rgb_image, 0, 1)
            ax[i].imshow(rgb_image)
            ax[i].axis("off")
            ax[i].set_title(f"{self.time_steps[i].capitalize()} Image")

        cmap = plt.get_cmap("jet", num_classes)
        norm = Normalize(vmin=0, vmax=num_classes - 1)

        mask_ax_index = len(self.time_steps)
        ax[mask_ax_index].imshow(mask, cmap=cmap, norm=norm)
        ax[mask_ax_index].axis("off")
        ax[mask_ax_index].set_title("Ground Truth Mask")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            pred_ax_index = mask_ax_index + 1
            ax[pred_ax_index].imshow(prediction, cmap=cmap, norm=norm)
            ax[pred_ax_index].axis("off")
            ax[pred_ax_index].set_title("Predicted Mask")

        legend_ax_index = -1
        class_names = sample.get("class_names", self.class_names)
        positions = np.linspace(0, 1, num_classes) if num_classes > 1 else [0.5]

        legend_handles = [
            mpatches.Patch(color=cmap(pos), label=class_names[i])
            for i, pos in enumerate(positions)
        ]
        ax[legend_ax_index].legend(handles=legend_handles, loc="center")
        ax[legend_ax_index].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        return fig
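When `use_metadata=True`, `_get_coords` derives one location per sample from the georeferencing of the first time step. It uses the image center, computed either from the top-left tie point plus half the raster size times the pixel resolution, or from the midpoints of the bounds. The stdlib sketch below uses made-up georeferencing values to show that the two agree, as long as the y resolution keeps its sign: it is negative for a north-up raster. The helper names are illustrative only.

```python
# Affine geotransform of a north-up raster:
# x = x0 + col * x_res, y = y0 + row * y_res, with y_res < 0.
def center_from_geotransform(x0, y0, x_res, y_res, width, height):
    """Raster center in its native CRS, returned as (y, x)."""
    center_x = x0 + (width / 2) * x_res
    center_y = y0 + (height / 2) * y_res  # negative y_res moves south from the top edge
    return center_y, center_x


def center_from_bounds(left, bottom, right, top):
    """The same center from (left, bottom, right, top) bounds."""
    return (bottom + top) / 2, (left + right) / 2


# Hypothetical 100 x 50 pixel raster with 30 m pixels, top-left corner at (500000, 4100000).
x0, y0, x_res, y_res = 500_000.0, 4_100_000.0, 30.0, -30.0
width, height = 100, 50
left, top = x0, y0
right, bottom = x0 + width * x_res, y0 + height * y_res

print(center_from_geotransform(x0, y0, x_res, y_res, width, height))  # (4099250.0, 501500.0)
print(center_from_bounds(left, bottom, right, top))                   # (4099250.0, 501500.0)
```

Unlike `CarbonFluxNonGeo`, `BurnIntensityNonGeo` does not reproject this point, so it stays in the raster's own CRS.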
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, use_full_data=True, no_data_replace=0.0001, no_label_replace=-1, use_metadata=False)

Initialize the BurnIntensity dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train' or 'val'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to output. Defaults to all bands.

  • transform (Optional[Compose], default: None ) –

    Albumentations transform to be applied.

  • use_full_data (bool, default: True ) –

    Whether to use the full data or only cases with less than 25 percent zeros.

  • no_data_replace (Optional[float], default: 0.0001 ) –

    Value to replace NaNs in images.

  • no_label_replace (Optional[int], default: -1 ) –

    Value to replace NaNs in labels.

  • use_metadata (bool, default: False ) –

    Whether to return metadata (image center location).

Source code in terratorch/datasets/burn_intensity.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    use_full_data: bool = True,
    no_data_replace: float | None = 0.0001,
    no_label_replace: int | None = -1,
    use_metadata: bool = False,
) -> None:
    """Initialize the BurnIntensity dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train' or 'val'.
        bands (Sequence[str]): Bands to output. Defaults to all bands.
        transform (Optional[A.Compose]): Albumentations transform to be applied.
        use_full_data (bool): Whether to use the full data or only cases with less than 25 percent zeros.
        no_data_replace (Optional[float]): Value to replace NaNs in images.
        no_label_replace (Optional[int]): Value to replace NaNs in labels.
        use_metadata (bool): Whether to return metadata (image center location).
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])

    self.data_root = Path(data_root)

    # Read the CSV file to get the list of cases to include
    csv_file_key = "full" if use_full_data else "limited"
    csv_path = self.data_root / self.CSV_FILES[csv_file_key]
    df = pd.read_csv(csv_path)
    casenames = df["Case_Name"].tolist()

    split_file = self.data_root / f"{split}.txt"
    with open(split_file) as f:
        split_images = [line.strip() for line in f.readlines()]

    split_images = [img for img in split_images if self._extract_casename(img) in casenames]

    # Build the samples list
    self.samples = []
    for image_filename in split_images:
        image_files = []
        for time_step in self.time_steps:
            image_file = self.data_root / time_step / image_filename
            image_files.append(str(image_file))
        mask_filename = image_filename.replace("HLS_", "BS_")
        mask_file = self.data_root / "pre" / mask_filename
        self.samples.append({
            "image_files": image_files,
            "mask_file": str(mask_file),
            "casename": self._extract_casename(image_filename),
        })

    self.use_metadata = use_metadata
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace

    self.transform = transform if transform else default_transform
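The constructor reads `{split}.txt` and keeps only the entries whose casename (the file stem without its `HLS_`/`BS_` marker) appears in the `Case_Name` column of the selected CSV. It then pairs each entry with one path per time step and a mask path under `pre/`, obtained by replacing `HLS_` with `BS_`. The stdlib sketch below mirrors that bookkeeping. The root directory and file names are hypothetical, and the CSV lookup is replaced by a plain list.

```python
from pathlib import Path

TIME_STEPS = ("pre", "during", "post")


def extract_casename(filename: str) -> str:
    """Drop directory, extension and the 'HLS_'/'BS_' markers, as _extract_casename does."""
    stem = Path(filename).stem
    return stem.replace("HLS_", "").replace("BS_", "")


def build_samples(data_root: str, split_images: list[str], casenames: list[str]) -> list[dict]:
    """Keep allowed split entries and attach per-time-step image paths and a mask path."""
    root = Path(data_root)
    allowed = set(casenames)  # set membership instead of scanning a list per entry
    samples = []
    for name in split_images:
        casename = extract_casename(name)
        if casename not in allowed:
            continue  # case not listed in the chosen CSV
        samples.append({
            "image_files": [str(root / step / name) for step in TIME_STEPS],
            "mask_file": str(root / "pre" / name.replace("HLS_", "BS_")),
            "casename": casename,
        })
    return samples


# Hypothetical split entries and CSV case list.
samples = build_samples("/data/burn", ["HLS_case1.tif", "HLS_case2.tif"], ["case1"])
print(len(samples), samples[0]["mask_file"])
```

Using a set only changes lookup cost. The filtering result is the same as the list-based check in the constructor.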
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Any

    A matplotlib Figure with the rendered sample.

Source code in terratorch/datasets/burn_intensity.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Any:
    """Plot a sample from the dataset.

    Args:
        sample: A sample returned by `__getitem__`.
        suptitle: Optional string to use as a suptitle.

    Returns:
        A matplotlib Figure with the rendered sample.
    """
    num_images = len(self.time_steps) + 2
    if "prediction" in sample:
        num_images += 1

    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    images = sample["image"]  # (C, T, H, W)
    mask = sample["mask"].numpy()
    num_classes = self.num_classes  # full class count, not just the values present in this mask

    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 5, 5))

    for i in range(len(self.time_steps)):
        image = images[:, i, :, :]  # (C, H, W)
        image = np.transpose(image, (1, 2, 0))  # (H, W, C)
        rgb_image = image[..., rgb_indices]
        rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
        rgb_image = np.clip(rgb_image, 0, 1)
        ax[i].imshow(rgb_image)
        ax[i].axis("off")
        ax[i].set_title(f"{self.time_steps[i].capitalize()} Image")

    cmap = plt.get_cmap("jet", num_classes)
    norm = Normalize(vmin=0, vmax=num_classes - 1)

    mask_ax_index = len(self.time_steps)
    ax[mask_ax_index].imshow(mask, cmap=cmap, norm=norm)
    ax[mask_ax_index].axis("off")
    ax[mask_ax_index].set_title("Ground Truth Mask")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        pred_ax_index = mask_ax_index + 1
        ax[pred_ax_index].imshow(prediction, cmap=cmap, norm=norm)
        ax[pred_ax_index].axis("off")
        ax[pred_ax_index].set_title("Predicted Mask")

    legend_ax_index = -1
    class_names = sample.get("class_names", self.class_names)
    positions = np.linspace(0, 1, num_classes) if num_classes > 1 else [0.5]

    legend_handles = [
        mpatches.Patch(color=cmap(pos), label=class_names[i])
        for i, pos in enumerate(positions)
    ]
    ax[legend_ax_index].legend(handles=legend_handles, loc="center")
    ax[legend_ax_index].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    plt.tight_layout()
    return fig
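The legend in `plot` only matches the mask colors if both use the same colormap positions. With `Normalize(vmin=0, vmax=num_classes - 1)`, class `i` lands at `i / (num_classes - 1)`, which is where `np.linspace(0, 1, num_classes)` places the `i`-th legend patch. That correspondence labels classes correctly only when `num_classes` is the total number of classes, not the number of distinct values in one mask. The stdlib sketch below contrasts the two choices. The helper names and the example mask values are illustrative.

```python
CLASS_NAMES = ("No burn", "Unburned to Very Low", "Low Severity",
               "Moderate Severity", "High Severity")


def colormap_position(class_index: int, num_classes: int) -> float:
    """Where Normalize(vmin=0, vmax=num_classes - 1) puts a class (assumes num_classes > 1).

    Values above 1 are clipped here, since by default a colormap draws them
    with its top color.
    """
    return min(class_index / (num_classes - 1), 1.0)


def legend_label_at(position: float, num_classes: int) -> str:
    """Label of the legend patch that np.linspace(0, 1, num_classes) puts at `position`."""
    for i in range(num_classes):
        if abs(i / (num_classes - 1) - position) < 1e-9:
            return CLASS_NAMES[i]
    raise ValueError(f"no legend patch at {position}")


mask_values = (0, 3, 4)  # classes present in one hypothetical mask

# Full class count: each mask value lands on its own legend entry.
full = {v: legend_label_at(colormap_position(v, 5), 5) for v in mask_values}
print(full)  # {0: 'No burn', 3: 'Moderate Severity', 4: 'High Severity'}

# Counting only the distinct values (3 here) shifts and merges the labels.
n = len(mask_values)
unique = {v: legend_label_at(colormap_position(v, n), n) for v in mask_values}
print(unique)  # {0: 'No burn', 3: 'Low Severity', 4: 'Low Severity'}
```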

terratorch.datasets.carbonflux

CarbonFluxNonGeo

Bases: NonGeoDataset

Dataset for Carbon Flux regression from HLS images and MERRA data.
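When `transform` is not `None`, the constructor indexes it by modality name and fills every modality listed in `modalities` that has no entry with `default_transform`. It then wraps the result in `MultimodalTransforms(..., shared=False)`. Keys for modalities outside `modalities` are ignored. The stdlib sketch below reproduces only the filling step. `default_transform` and `flip_horizontal` here are hypothetical placeholders, not TerraTorch or Albumentations objects.

```python
def default_transform(sample):
    """Placeholder for the library's default transform (hypothetical here)."""
    return sample


def flip_horizontal(rows):
    """Hypothetical image-only transform: reverse each row of a nested list."""
    return [row[::-1] for row in rows]


def fill_transforms(transform: dict, modalities: tuple) -> dict:
    """One transform per modality, falling back to the default where none is given."""
    return {m: transform[m] if m in transform else default_transform for m in modalities}


filled = fill_transforms({"image": flip_horizontal}, ("image", "merra_vars"))
print(filled["image"]([[1, 2, 3]]))                # [[3, 2, 1]]
print(filled["merra_vars"] is default_transform)  # True
```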

Source code in terratorch/datasets/carbonflux.py
class CarbonFluxNonGeo(NonGeoDataset):
    """Dataset for [Carbon Flux](https://huggingface.co/datasets/ibm-nasa-geospatial/hls_merra2_gppFlux) regression from HLS images and MERRA data."""

    all_band_names = (
        "BLUE", "GREEN", "RED", "NIR", "SWIR_1", "SWIR_2",
    )

    rgb_bands = (
        "RED", "GREEN", "BLUE",
    )

    merra_var_names = (
        "T2MIN", "T2MAX", "T2MEAN", "TSMDEWMEAN", "GWETROOT",
        "LHLAND", "SHLAND", "SWLAND", "PARDFLAND", "PRECTOTLAND"
    )

    splits = {"train": "train", "test": "test"}

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    metadata_file = "data_train_hls_37sites_v0_1.csv"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: dict[str, A.Compose] | None = None,
        gpp_mean: float | None = None,
        gpp_std: float | None = None,
        no_data_replace: float | None = 0.0001,
        use_metadata: bool = False,
        modalities: Sequence[str] = ("image", "merra_vars")
    ) -> None:
        """Initialize the CarbonFluxNonGeo dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): 'train' or 'test'.
            bands (Sequence[str]): Bands to use. Defaults to all bands.
            transform (Optional[dict[str, A.Compose]]): Albumentations transforms keyed by
                modality name; modalities without an entry use the default transform.
            gpp_mean (float): Mean for GPP normalization. Required.
            gpp_std (float): Standard deviation for GPP normalization. Required.
            no_data_replace (Optional[float]): Value to replace NO_DATA values in images.
            use_metadata (bool): Whether to return metadata (coordinates and date).
            modalities (Sequence[str]): Modalities to return under the "image" key.
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)

        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(band) for band in bands]

        self.data_root = Path(data_root)

        # Load the CSV file with metadata
        csv_file = self.data_root / self.metadata_file
        df = pd.read_csv(csv_file)

        # Get list of image filenames in the split directory
        image_dir = self.data_root / self.split
        image_files = [f.name for f in image_dir.glob("*.tiff")]

        df["Chip"] = df["Chip"].str.replace(r"\.tif$", ".tiff", regex=True)
        # Filter the DataFrame to include only rows with 'Chip' in image_files
        df = df[df["Chip"].isin(image_files)]

        # Build the samples list
        self.samples = []
        for _, row in df.iterrows():
            image_filename = row["Chip"]
            image_path = image_dir / image_filename
            # MERRA vectors
            merra_vars = row[list(self.merra_var_names)].values.astype(np.float32)
            # GPP target
            gpp = row["GPP"]
            self.samples.append({
                "image_path": str(image_path),
                "merra_vars": merra_vars,
                "gpp": gpp,
            })

        if gpp_mean is None or gpp_std is None:
            msg = "Mean and standard deviation for GPP must be provided."
            raise ValueError(msg)
        self.gpp_mean = gpp_mean
        self.gpp_std = gpp_std

        self.use_metadata = use_metadata
        self.modalities = modalities
        self.no_data_replace = no_data_replace

        if transform is None:
            self.transform = MultimodalToTensor(self.modalities)
        else:
            transform = {m: transform[m] if m in transform else default_transform
                for m in self.modalities}
            self.transform = MultimodalTransforms(transform, shared=False)

    def __len__(self) -> int:
        return len(self.samples)

    def _load_file(self, path: str, nan_replace: float | int | None = None):
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data

    def _get_coords(self, image) -> torch.Tensor:
        """Extract the center coordinates from the image geospatial metadata."""
        # Midpoints of the bounds give the center regardless of the sign of the
        # y resolution (negative for north-up rasters).
        left, bottom, right, top = image.rio.bounds()
        center_lon = (left + right) / 2
        center_lat = (bottom + top) / 2

        src_crs = image.rio.crs
        dst_crs = "EPSG:4326"

        transformer = pyproj.Transformer.from_crs(src_crs, dst_crs, always_xy=True)
        lon, lat = transformer.transform(center_lon, center_lat)

        coords = np.array([lat, lon], dtype=np.float32)
        return torch.from_numpy(coords)

    def _get_date(self, filename: str) -> torch.Tensor:
        """Extract the date from the filename."""
        base_filename = os.path.basename(filename)
        pattern = r"HLS\..{3}\.[A-Z0-9]{6}\.(?P<date>\d{7}T\d{6})\..*\.tiff$"
        match = re.match(pattern, base_filename)
        if not match:
            msg = f"Filename {filename} does not match expected pattern."
            raise ValueError(msg)

        date_str = match.group("date")
        year = int(date_str[:4])
        julian_day = int(date_str[4:7])

        date_tensor = torch.tensor([year, julian_day], dtype=torch.int32)
        return date_tensor

    def __getitem__(self, idx: int) -> dict[str, Any]:
        sample = self.samples[idx]
        image_path = sample["image_path"]

        image = self._load_file(image_path, nan_replace=self.no_data_replace)

        if self.use_metadata:
            location_coords = self._get_coords(image)
            temporal_coords = self._get_date(os.path.basename(image_path))

        image = image.to_numpy()  # (C, H, W)
        image = image[self.band_indices, ...]
        image = np.moveaxis(image, 0, -1) # (H, W, C)

        merra_vars = np.array(sample["merra_vars"])
        target = np.array(sample["gpp"])
        target_norm = (target - self.gpp_mean) / self.gpp_std
        target_norm = torch.tensor(target_norm, dtype=torch.float32)
        output = {
            "image": image.astype(np.float32),
            "merra_vars": merra_vars,
        }

        if self.transform:
            output = self.transform(output)

        output = {
            "image": {m: output[m] for m in self.modalities if m in output},
            "mask": target_norm
        }
        if self.use_metadata:
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def plot(self, sample: dict[str, Any], suptitle: str | None = None) -> Any:
        """Plot a sample from the dataset.

        Args:
            sample: A sample returned by `__getitem__`.
            suptitle: Optional title for the figure.

        Returns:
            A matplotlib figure with the rendered sample.
        """
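        image = sample["image"]
        if isinstance(image, dict):  # __getitem__ nests modalities under "image"
            image = image["image"]
        image = np.asarray(image)  # (C, H, W)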

        image = np.transpose(image, (1, 2, 0))  # (H, W, C)

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        rgb_image = image[..., rgb_indices]

        rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
        rgb_image = np.clip(rgb_image, 0, 1)

        fig, ax = plt.subplots(1, 1, figsize=(6, 6))
        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title("Image")

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        return fig
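`_get_date` reads the acquisition date from the chip file name. It expects an HLS-style name whose fourth dot-separated field is `YYYYDDDTHHMMSS` (year, day of year, time) and that ends in `.tiff`. It returns the year and the day of year. The stdlib sketch below applies the same regular expression. The chip name is hypothetical, and the final line converts the day of year back to a calendar date only to make the encoding concrete.

```python
import datetime
import re

HLS_PATTERN = re.compile(r"HLS\..{3}\.[A-Z0-9]{6}\.(?P<date>\d{7}T\d{6})\..*\.tiff$")


def hls_year_doy(filename: str) -> tuple[int, int]:
    """Return (year, day of year) from an HLS-style chip name."""
    match = HLS_PATTERN.match(filename)
    if match is None:
        raise ValueError(f"Filename {filename} does not match expected pattern.")
    date_str = match.group("date")  # YYYYDDD + "T" + HHMMSS
    return int(date_str[:4]), int(date_str[4:7])


name = "HLS.S30.T10SEG.2021045T184549.v2.0.tiff"  # hypothetical chip name
year, doy = hls_year_doy(name)
print(year, doy)  # 2021 45
print(datetime.date(year, 1, 1) + datetime.timedelta(days=doy - 1))  # 2021-02-14
```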
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, gpp_mean=None, gpp_std=None, no_data_replace=0.0001, use_metadata=False, modalities=('image', 'merra_vars'))

Initialize the CarbonFluxNonGeo dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    'train' or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to use. Defaults to all bands.

  • transform (Optional[dict[str, Compose]], default: None ) –

    Albumentations transforms keyed by modality name; modalities without an entry use the default transform.

  • gpp_mean (float, default: None ) –

    Mean for GPP normalization. Required: a ValueError is raised if it is not provided.

  • gpp_std (float, default: None ) –

    Standard deviation for GPP normalization. Required: a ValueError is raised if it is not provided.

  • no_data_replace (Optional[float], default: 0.0001 ) –

    Value to replace NO_DATA values in images.

  • use_metadata (bool, default: False ) –

    Whether to return metadata (coordinates and date).

  • modalities (Sequence[str], default: ('image', 'merra_vars') ) –

    Modalities to return under the "image" key.

Source code in terratorch/datasets/carbonflux.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: dict[str, A.Compose] | None = None,
    gpp_mean: float | None = None,
    gpp_std: float | None = None,
    no_data_replace: float | None = 0.0001,
    use_metadata: bool = False,
    modalities: Sequence[str] = ("image", "merra_vars")
) -> None:
    """Initialize the CarbonFluxNonGeo dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): 'train' or 'test'.
        bands (Sequence[str]): Bands to use. Defaults to all bands.
        transform (Optional[dict[str, A.Compose]]): Albumentations transforms keyed by
            modality name; modalities without an entry use the default transform.
        gpp_mean (float): Mean for GPP normalization. Required.
        gpp_std (float): Standard deviation for GPP normalization. Required.
        no_data_replace (Optional[float]): Value to replace NO_DATA values in images.
        use_metadata (bool): Whether to return metadata (coordinates and date).
        modalities (Sequence[str]): Modalities to return under the "image" key.
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)

    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(band) for band in bands]

    self.data_root = Path(data_root)

    # Load the CSV file with metadata
    csv_file = self.data_root / self.metadata_file
    df = pd.read_csv(csv_file)

    # Get list of image filenames in the split directory
    image_dir = self.data_root / self.split
    image_files = [f.name for f in image_dir.glob("*.tiff")]

    df["Chip"] = df["Chip"].str.replace(r"\.tif$", ".tiff", regex=True)
    # Filter the DataFrame to include only rows with 'Chip' in image_files
    df = df[df["Chip"].isin(image_files)]

    # Build the samples list
    self.samples = []
    for _, row in df.iterrows():
        image_filename = row["Chip"]
        image_path = image_dir / image_filename
        # MERRA vectors
        merra_vars = row[list(self.merra_var_names)].values.astype(np.float32)
        # GPP target
        gpp = row["GPP"]
        self.samples.append({
            "image_path": str(image_path),
            "merra_vars": merra_vars,
            "gpp": gpp,
        })

    if gpp_mean is None or gpp_std is None:
        msg = "Mean and standard deviation for GPP must be provided."
        raise ValueError(msg)
    self.gpp_mean = gpp_mean
    self.gpp_std = gpp_std

    self.use_metadata = use_metadata
    self.modalities = modalities
    self.no_data_replace = no_data_replace

    if transform is None:
        self.transform = MultimodalToTensor(self.modalities)
    else:
        transform = {m: transform[m] if m in transform else default_transform
            for m in self.modalities}
        self.transform = MultimodalTransforms(transform, shared=False)
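`gpp_mean` and `gpp_std` are effectively required: the constructor raises a `ValueError` if either is missing. `__getitem__` then returns `(GPP - gpp_mean) / gpp_std` under `"mask"`, so model outputs are in standardized units. The stdlib sketch below derives the statistics from training values and maps a prediction back to GPP units. The GPP values are made up. Using the population standard deviation (`statistics.pstdev`) is an assumption, since this page does not say which convention was used for the published statistics.

```python
import statistics


def normalize_gpp(gpp: float, gpp_mean: float, gpp_std: float) -> float:
    """Standardize GPP the way __getitem__ builds its "mask" target."""
    return (gpp - gpp_mean) / gpp_std


def denormalize_gpp(z: float, gpp_mean: float, gpp_std: float) -> float:
    """Map a standardized prediction back to GPP units."""
    return z * gpp_std + gpp_mean


# Made-up training-split GPP values; compute statistics on training data only.
train_gpp = [2.0, 6.0]
gpp_mean = statistics.fmean(train_gpp)  # 4.0
gpp_std = statistics.pstdev(train_gpp)  # 2.0 (population std, an assumption)

z = normalize_gpp(7.0, gpp_mean, gpp_std)
print(z)                                      # 1.5
print(denormalize_gpp(z, gpp_mean, gpp_std))  # 7.0
```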
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Any]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional title for the figure.

Returns:
  • Any

    A matplotlib figure with the rendered sample.

Source code in terratorch/datasets/carbonflux.py
def plot(self, sample: dict[str, Any], suptitle: str | None = None) -> Any:
    """Plot a sample from the dataset.

    Args:
        sample: A sample returned by `__getitem__`.
        suptitle: Optional title for the figure.

    Returns:
        A matplotlib figure with the rendered sample.
    """
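    image = sample["image"]
    if isinstance(image, dict):  # __getitem__ nests modalities under "image"
        image = image["image"]
    image = np.asarray(image)  # (C, H, W)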

    image = np.transpose(image, (1, 2, 0))  # (H, W, C)

    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    rgb_image = image[..., rgb_indices]

    rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
    rgb_image = np.clip(rgb_image, 0, 1)

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.imshow(rgb_image)
    ax.axis("off")
    ax.set_title("Image")

    if suptitle:
        plt.suptitle(suptitle)

    plt.tight_layout()
    return fig
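For display, `plot` rescales the RGB bands with a single global min-max stretch over all pixels and all three channels, then clips to `[0, 1]`. One very bright pixel can therefore compress the rest of the image toward black. The NumPy sketch below reproduces that stretch on synthetic data. For comparison it includes a per-channel 2nd to 98th percentile stretch, a common alternative that `plot` does not use. Both helper names are illustrative.

```python
import numpy as np


def minmax_stretch(rgb: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """The stretch used in plot(): one min and max over all pixels and channels."""
    scaled = (rgb - rgb.min()) / (rgb.max() - rgb.min() + eps)
    return np.clip(scaled, 0, 1)


def percentile_stretch(rgb: np.ndarray, low: float = 2, high: float = 98, eps: float = 1e-8) -> np.ndarray:
    """Alternative, not used by plot(): clip each channel to its low/high percentiles."""
    lo = np.percentile(rgb, low, axis=(0, 1), keepdims=True)
    hi = np.percentile(rgb, high, axis=(0, 1), keepdims=True)
    return np.clip((rgb - lo) / (hi - lo + eps), 0, 1)


rng = np.random.default_rng(0)
rgb = rng.uniform(0.0, 0.3, size=(8, 8, 3)).astype(np.float32)  # synthetic reflectances
rgb[0, 0] = 5.0  # one very bright pixel

print(float(np.median(minmax_stretch(rgb))))      # most pixels end up near 0
print(float(np.median(percentile_stretch(rgb))))  # mid-range values stay visible
```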

terratorch.datasets.forestnet

ForestNetNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for ForestNet.
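With `use_metadata=True`, `ForestNetNonGeo` derives temporal coordinates from image file names of the form `YYYY_MM_DD_cloud_<n>.png` (or `.npy`). It returns one `[year, day_of_year]` pair per selected image. The stdlib sketch below applies the same regular expression. The paths are hypothetical, and the explicit `None` check is an addition that the original helper does not have.

```python
import datetime
import re

# Same pattern as ForestNetNonGeo._get_dates.
DATE_PATTERN = re.compile(r"(\d{4})_(\d{2})_(\d{2})_cloud_\d+\.(png|npy)")


def year_and_julian_day(path: str) -> tuple[int, int]:
    """Return (year, day of year) encoded in a ForestNet image file name."""
    match = DATE_PATTERN.search(path)
    if match is None:
        raise ValueError(f"no date in {path!r}")
    year, month, day = (int(match.group(i)) for i in (1, 2, 3))
    return year, datetime.date(year, month, day).timetuple().tm_yday


print(year_and_julian_day("images/visible/2017_03_01_cloud_12.png"))  # (2017, 60)
print(year_and_julian_day("images/infrared/2016_03_01_cloud_0.npy"))  # (2016, 61), leap year
```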

Source code in terratorch/datasets/forestnet.py
class ForestNetNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [ForestNet](https://huggingface.co/datasets/ibm-nasa-geospatial/ForestNet)."""

    all_band_names = (
        "RED", "GREEN", "BLUE", "NIR", "SWIR_1", "SWIR_2"
    )

    rgb_bands = (
        "RED", "GREEN", "BLUE",
    )

    splits = ("train", "test", "val")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    default_label_map = {  # noqa: RUF012
        "Plantation": 0,
        "Smallholder agriculture": 1,
        "Grassland shrubland": 2,
        "Other": 3,
    }

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        label_map: dict[str, int] = default_label_map,
        transform: A.Compose | None = None,
        fraction: float = 1.0,
        bands: Sequence[str] = BAND_SETS["all"],
        use_metadata: bool = False,
    ) -> None:
        """
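        Initialize the ForestNetNonGeo dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            label_map (Dict[str, int]): Mapping from label names to integer labels.
            transform (Optional[A.Compose]): Albumentations transform to be applied to the images.
            fraction (float): Fraction of the dataset to use, drawn by stratified sampling on
                the 'merged_label' column. Defaults to 1.0 (use all data).
            bands (Sequence[str]): Bands to output. Defaults to all bands.
            use_metadata (bool): Whether to return metadata (location and dates).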
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits)}."
            raise ValueError(msg)
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.label_map = label_map

        # Load the CSV file corresponding to the split
        csv_file = self.data_root / f"{split}_filtered.csv"
        original_df = pd.read_csv(csv_file)

        # Apply stratified sampling if fraction < 1.0
        if fraction < 1.0:
            sss = StratifiedShuffleSplit(n_splits=1, test_size=1 - fraction, random_state=47)
            stratified_indices, _ = next(sss.split(original_df, original_df["merged_label"]))
            self.dataset = original_df.iloc[stratified_indices].reset_index(drop=True)
        else:
            self.dataset = original_df

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.dataset)

    def _get_coords(self, event_path: Path) -> torch.Tensor:
        auxiliary_path = event_path / "auxiliary"
        osm_json_path = auxiliary_path / "osm.json"

        with open(osm_json_path) as f:
            osm_data = json.load(f)
            lat = float(osm_data["closest_city"]["lat"])
            lon = float(osm_data["closest_city"]["lon"])
            lat_lon = np.asarray([lat, lon])

        return torch.tensor(lat_lon, dtype=torch.float32)

    def _get_dates(self, image_files: list) -> torch.Tensor:
        dates = []
        pattern = re.compile(r"(\d{4})_(\d{2})_(\d{2})_cloud_\d+\.(png|npy)")
        for img_path in image_files:
            match = pattern.search(img_path)
            year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
            date_obj = datetime.datetime(year, month, day)  # noqa: DTZ001
            julian_day = date_obj.timetuple().tm_yday
            date_tensor = torch.tensor([year, julian_day], dtype=torch.int32)
            dates.append(date_tensor)
        return torch.stack(dates, dim=0)

    def __getitem__(self, index: int):
        path = self.data_root / self.dataset["example_path"][index]
        label = self.map_label(index)

        visible_images, infrared_images, temporal_coords = self._load_images(path)

        visible_images = np.stack(visible_images, axis=0)
        infrared_images = np.stack(infrared_images, axis=0)
        merged_images = np.concatenate([visible_images, infrared_images], axis=-1)
        merged_images = merged_images[..., self.band_indices]  # (T, H, W, len(bands))
        output = {
            "image": merged_images.astype(np.float32)
        }

        if self.transform:
            output = self.transform(**output)

        if self.use_metadata:
            location_coords = self._get_coords(path)
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        output["label"] = label

        return output

    def _load_images(self, path: str):
        """Load visible and infrared images from the given event path"""
        visible_image_files = glob.glob(os.path.join(path, "images/visible/*_cloud_*.png"))
        infra_image_files = glob.glob(os.path.join(path, "images/infrared/*_cloud_*.npy"))

        selected_visible_images = self.select_images(visible_image_files)
        selected_infra_images = self.select_images(infra_image_files)

        dates = None
        if self.use_metadata:
            dates = self._get_dates(selected_visible_images)

        vis_images = [np.array(Image.open(img)) for img in selected_visible_images] # (T, H, W, C)
        inf_images = [np.load(img, allow_pickle=True) for img in selected_infra_images] # (T, H, W, C)
        return vis_images, inf_images, dates

    def least_cloudy_image(self, image_files):
        pattern = re.compile(r"(\d{4})_\d{2}_\d{2}_cloud_(\d+)\.(png|npy)")
        lowest_cloud_images = defaultdict(lambda: {"path": None, "cloud_value": float("inf")})

        for path in image_files:
            match = pattern.search(path)
            if match:
                year, cloud_value = match.group(1), int(match.group(2))
                if cloud_value < lowest_cloud_images[year]["cloud_value"]:
                    lowest_cloud_images[year] = {"path": path, "cloud_value": cloud_value}

        return [info["path"] for info in lowest_cloud_images.values()]

    def match_timesteps(self, image_files, selected_images):
        if len(selected_images) < 3:
            extra_imgs = [img for img in image_files if img not in selected_images]
            selected_images += extra_imgs[:3 - len(selected_images)]

        while len(selected_images) < 3:
            selected_images.append(selected_images[-1])
        return selected_images[:3]

    def select_images(self, image_files):
        selected = self.least_cloudy_image(image_files)
        return self.match_timesteps(image_files, selected)

    def map_label(self, index: int) -> int:
        """Map the label name to an integer label."""
        label_name = self.dataset["merged_label"][index]
        label = self.label_map[label_name]
        return label

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None):

        num_images = sample["image"].shape[1] + 1

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        fig, ax = plt.subplots(1, num_images, figsize=(15, 5))

        for i in range(sample["image"].shape[1]):
            image = sample["image"][:, i, :, :]
            if torch.is_tensor(image):
                image = image.permute(1, 2, 0).numpy()
            rgb_image = image[..., rgb_indices]
            rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
            rgb_image = np.clip(rgb_image, 0, 1)
            ax[i].imshow(rgb_image)
            ax[i].axis("off")
            ax[i].set_title(f"Timestep {i + 1}")

        legend_handles = [Rectangle((0, 0), 1, 1, color="blue")]
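        # label_map maps names to integers, so invert it to look up the class name
        label_names = {value: name for name, value in self.label_map.items()}
        legend_label = [label_names.get(int(sample["label"]), "Unknown Label")]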
        ax[-1].legend(legend_handles, legend_label, loc="center")
        ax[-1].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        return fig
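The frame selection performed by `_load_images` (keep the least cloudy image per year, then pad to three timesteps) and the date conversion used for `temporal_coords` can be sketched in plain Python. This is a simplified illustration, not the TerraTorch code: the function names and file names are hypothetical, and, as in the original, files that do not match the naming pattern are ignored during selection.

```python
import datetime
import re

# Patterns mirroring ForestNetNonGeo: YYYY_MM_DD_cloud_<score>.(png|npy)
CLOUD_PATTERN = re.compile(r"(\d{4})_\d{2}_\d{2}_cloud_(\d+)\.(png|npy)")
DATE_PATTERN = re.compile(r"(\d{4})_(\d{2})_(\d{2})_cloud_\d+\.(png|npy)")


def least_cloudy_per_year(files: list[str]) -> list[str]:
    """For each year, keep the file with the lowest cloud score (first-seen year order)."""
    best: dict[str, tuple[int, str]] = {}
    for path in files:
        match = CLOUD_PATTERN.search(path)
        if match is None:
            continue  # files outside the naming scheme are ignored
        year, cloud = match.group(1), int(match.group(2))
        if year not in best or cloud < best[year][0]:
            best[year] = (cloud, path)
    return [path for _, path in best.values()]


def pad_to_timesteps(files: list[str], selected: list[str], n: int = 3) -> list[str]:
    """Top up with unused files, then repeat the last file, until n are selected."""
    selected = list(selected)  # copy so the caller's list is not mutated
    unused = [f for f in files if f not in selected]
    selected += unused[: max(0, n - len(selected))]
    while len(selected) < n:
        selected.append(selected[-1])  # assumes at least one file was selected
    return selected[:n]


def to_year_and_doy(path: str) -> tuple[int, int]:
    """File name -> (year, day of year), the pair stored in temporal_coords."""
    year, month, day = (int(g) for g in DATE_PATTERN.search(path).groups()[:3])
    return year, datetime.date(year, month, day).timetuple().tm_yday
```

With one cloudy and one clear 2017 image plus a single 2018 image, the clear 2017 image and the 2018 image are selected first, and the remaining 2017 image fills the third timestep.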
__init__(data_root, split='train', label_map=default_label_map, transform=None, fraction=1.0, bands=BAND_SETS['all'], use_metadata=False)

Initialize the ForestNetNonGeo dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • label_map (Dict[str, int], default: default_label_map ) –

    Mapping from label names to integer labels.

  • transform (Compose | None, default: None ) –

    Transformations to be applied to the images.

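  • fraction (float, default: 1.0 ) –

    Fraction of the dataset to use. Defaults to 1.0 (use all data).

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands that should be output by the dataset. Defaults to all bands.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info (location_coords and temporal_coords).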

Source code in terratorch/datasets/forestnet.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    label_map: dict[str, int] = default_label_map,
    transform: A.Compose | None = None,
    fraction: float = 1.0,
    bands: Sequence[str] = BAND_SETS["all"],
    use_metadata: bool = False,
) -> None:
    """
    Initialize the ForestNetNonGeo dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        label_map (Dict[str, int]): Mapping from label names to integer labels.
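        transform (A.Compose | None): Transformations to be applied to the images.
        fraction (float): Fraction of the dataset to use. Defaults to 1.0 (use all data).
        bands (Sequence[str]): Bands that should be output by the dataset. Defaults to all bands.
        use_metadata (bool): Whether to return metadata info (location and time).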
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits)}."
        raise ValueError(msg)
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(b) for b in bands]

    self.use_metadata = use_metadata

    self.data_root = Path(data_root)
    self.label_map = label_map

    # Load the CSV file corresponding to the split
    csv_file = self.data_root / f"{split}_filtered.csv"
    original_df = pd.read_csv(csv_file)

    # Apply stratified sampling if fraction < 1.0
    if fraction < 1.0:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=1 - fraction, random_state=47)
        stratified_indices, _ = next(sss.split(original_df, original_df["merged_label"]))
        self.dataset = original_df.iloc[stratified_indices].reset_index(drop=True)
    else:
        self.dataset = original_df

    self.transform = transform if transform else default_transform
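When `fraction < 1.0`, the constructor keeps a class-stratified subset of the CSV rows through scikit-learn's `StratifiedShuffleSplit`, so class proportions are roughly preserved. The standard-library sketch below illustrates the same idea with a simpler per-class random sample. It is a stand-in, not scikit-learn's algorithm, and it will not select the same rows; `stratified_subset` is a hypothetical helper.

```python
import random
from collections import defaultdict


def stratified_subset(labels: list[str], fraction: float, seed: int = 47) -> list[int]:
    """Return sorted row indices keeping about `fraction` of every class."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    rows_by_class: dict[str, list[int]] = defaultdict(list)
    for row, label in enumerate(labels):
        rows_by_class[label].append(row)
    rng = random.Random(seed)  # fixed seed -> reproducible subset
    keep: list[int] = []
    for rows in rows_by_class.values():
        # Round per class and keep at least one row so no class disappears.
        n_keep = max(1, round(len(rows) * fraction))
        keep.extend(rng.sample(rows, n_keep))
    return sorted(keep)
```

With a pandas DataFrame, the result could be applied as `df.iloc[stratified_subset(df["merged_label"].tolist(), 0.5)]`.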
map_label(index)

Map the label name to an integer label.

Source code in terratorch/datasets/forestnet.py
def map_label(self, index: int) -> int:
    """Map the label name to an integer label."""
    label_name = self.dataset["merged_label"][index]
    label = self.label_map[label_name]
    return label
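`label_map` maps class names (the `merged_label` column) to integers, so converting an integer label back into a name, for example for a plot legend, requires the inverse mapping. A minimal sketch; the class names are placeholders, not the contents of `default_label_map`, and both helpers are hypothetical.

```python
# Placeholder label map; the real default_label_map is defined in terratorch/datasets/forestnet.py.
LABEL_MAP = {"class_a": 0, "class_b": 1, "class_c": 2}


def map_label(label_name: str, label_map: dict[str, int]) -> int:
    """Name -> integer, as ForestNetNonGeo.map_label does for one CSV row."""
    return label_map[label_name]


def label_name(label: int, label_map: dict[str, int]) -> str:
    """Integer -> name, falling back to a default for unknown labels."""
    inverse = {value: name for name, value in label_map.items()}
    return inverse.get(int(label), "Unknown Label")
```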

terratorch.datasets.fire_scars

FireScarsHLS

Bases: RasterDataset

RasterDataset implementation for fire scars input images.

Source code in terratorch/datasets/fire_scars.py
class FireScarsHLS(RasterDataset):
    """RasterDataset implementation for fire scars input images."""

    filename_glob = "subsetted*_merged.tif"
    filename_regex = r"subsetted_512x512_HLS\..30\..{6}\.(?P<date>[0-9]*)\.v1\.4_merged\.tif"
    date_format = "%Y%j"
    is_image = True
    separate_files = False
    all_bands = ("B02", "B03", "B04", "B8A", "B11", "B12")
    rgb_bands = ("B04", "B03", "B02")
FireScarsNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for fire scars.

Source code in terratorch/datasets/fire_scars.py
class FireScarsNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [fire scars](https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars)."""
    all_band_names = (
        "BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    num_classes = 2
    splits = {"train": "training", "val": "validation"}   # Only train and val splits available

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train' or 'val'. Defaults to 'train'.
            bands (list[str]): Bands that should be output by the dataset. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the corresponding data module,
                should not include normalization. Defaults to None, which applies ToTensorV2().
            no_data_replace (float | None): Replace NaN values in input images with this value.
                If None, does no replacement. Defaults to 0.
            no_label_replace (int | None): Replace NaN values in label with this value.
                If None, does no replacement. Defaults to -1.
            use_metadata (bool): Whether to return metadata info (time and location).
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits)}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
        self.data_root = Path(data_root)

        input_dir = self.data_root / split_name
        self.image_files = sorted(glob.glob(os.path.join(input_dir, "*_merged.tif")))
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(input_dir, "*.mask.tif")))

        self.use_metadata = use_metadata
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace

        # If no transform is given, only convert the sample to torch tensors
        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def _get_date(self, index: int) -> torch.Tensor:
        file_name = self.image_files[index]
        base_filename = os.path.basename(file_name)

        filename_regex = r"subsetted_512x512_HLS\.S30\.T[0-9A-Z]{5}\.(?P<date>[0-9]+)\.v1\.4_merged\.tif"
        match = re.match(filename_regex, base_filename)
        date_str = match.group("date")
        year = int(date_str[:4])
        julian_day = int(date_str[4:])

        return torch.tensor([[year, julian_day]], dtype=torch.float32)

    def _get_coords(self, image: DataArray) -> torch.Tensor:
        px = image.x.shape[0] // 2
        py = image.y.shape[0] // 2

        # get center point to reproject to lat/lon
        point = image.isel(band=0, x=slice(px, px + 1), y=slice(py, py + 1))
        point = point.rio.reproject("epsg:4326")

        lat_lon = np.asarray([point.y[0], point.x[0]])

        return torch.tensor(lat_lon, dtype=torch.float32)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace)

        location_coords, temporal_coords = None, None
        if self.use_metadata:
            location_coords = self._get_coords(image)
            temporal_coords = self._get_date(index)

        # to channels last
        image = image.to_numpy()
        image = np.moveaxis(image, 0, -1)

        # filter bands
        image = image[..., self.band_indices]

        output = {
            "image": image.astype(np.float32),
            "mask": self._load_file(
                self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[0],
        }
        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        if self.use_metadata:
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def _load_file(self, path: Path, nan_replace: int | float | None = None) -> DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        num_images = 4

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        # RGB -> channels-last
        image = sample["image"][rgb_indices, ...].permute(1, 2, 0).numpy()
        mask = sample["mask"].numpy()

        image = clip_image_percentile(image)

        if "prediction" in sample:
            prediction = sample["prediction"]
            num_images += 1
        else:
            prediction = None

        fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")

        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
        ax[1].axis("off")
        ax[1].title.set_text("Image")
        ax[1].imshow(image)

        ax[2].axis("off")
        ax[2].title.set_text("Ground Truth Mask")
        ax[2].imshow(mask, cmap="jet", norm=norm)

        ax[3].axis("off")
        ax[3].title.set_text("GT Mask on Image")
        ax[3].imshow(image)
        ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)

        if "prediction" in sample:
            ax[4].title.set_text("Predicted Mask")
            ax[4].imshow(prediction, cmap="jet", norm=norm)

        cmap = plt.get_cmap("jet")
        legend_data = [[i, cmap(norm(i)), str(i)] for i in range(self.num_classes)]
        handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
        labels = [n for k, c, n in legend_data]
        ax[0].legend(handles, labels, loc="center")
        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, no_data_replace=0, no_label_replace=-1, use_metadata=False)

Constructor

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train' or 'val'.

  • bands (list[str], default: BAND_SETS['all'] ) –

    Bands that should be output by the dataset. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the corresponding data module, should not include normalization. Defaults to None, which applies ToTensorV2().

  • no_data_replace (float | None, default: 0 ) –

    Replace NaN values in input images with this value. If None, does no replacement. Defaults to 0.

  • no_label_replace (int | None, default: -1 ) –

    Replace NaN values in label with this value. If None, does no replacement. Defaults to -1.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info (time and location).

Source code in terratorch/datasets/fire_scars.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    no_data_replace: float | None = 0,
    no_label_replace: int | None = -1,
    use_metadata: bool = False,
) -> None:
    """Constructor

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train' or 'val'. Defaults to 'train'.
        bands (list[str]): Bands that should be output by the dataset. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the corresponding data module,
            should not include normalization. Defaults to None, which applies ToTensorV2().
        no_data_replace (float | None): Replace NaN values in input images with this value.
            If None, does no replacement. Defaults to 0.
        no_label_replace (int | None): Replace NaN values in label with this value.
            If None, does no replacement. Defaults to -1.
        use_metadata (bool): Whether to return metadata info (time and location).
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits)}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
    self.data_root = Path(data_root)

    input_dir = self.data_root / split_name
    self.image_files = sorted(glob.glob(os.path.join(input_dir, "*_merged.tif")))
    self.segmentation_mask_files = sorted(glob.glob(os.path.join(input_dir, "*.mask.tif")))

    self.use_metadata = use_metadata
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace

    # If no transform is given, only convert the sample to torch tensors
    self.transform = transform if transform else default_transform
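The constructor pairs images and masks purely by sorted file order, so a missing or extra file silently shifts every later pair. A hypothetical sanity check, assuming the `<stem>_merged.tif` / `<stem>.mask.tif` naming matched by the glob patterns above, could look like this:

```python
from pathlib import Path


def pair_image_and_mask_files(
    image_files: list[str],
    mask_files: list[str],
    image_suffix: str = "_merged.tif",
    mask_suffix: str = ".mask.tif",
) -> list[tuple[str, str]]:
    """Pair files by sorted order, then verify that each pair shares the same stem."""
    images, masks = sorted(image_files), sorted(mask_files)
    if len(images) != len(masks):
        raise ValueError(f"{len(images)} images but {len(masks)} masks")
    pairs = []
    for image, mask in zip(images, masks):
        image_stem = Path(image).name.removesuffix(image_suffix)
        mask_stem = Path(mask).name.removesuffix(mask_suffix)
        if image_stem != mask_stem:
            raise ValueError(f"Mismatched pair: {image} / {mask}")
        pairs.append((image, mask))
    return pairs
```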
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    a sample returned by `__getitem__`

  • suptitle (str | None, default: None ) –

    optional string to use as a suptitle

Returns:
  • Figure

    a matplotlib Figure with the rendered sample

Source code in terratorch/datasets/fire_scars.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
    """Plot a sample from the dataset.

    Args:
        sample: a sample returned by :meth:`__getitem__`
        suptitle: optional string to use as a suptitle

    Returns:
        a matplotlib Figure with the rendered sample
    """
    num_images = 4

    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    # RGB -> channels-last
    image = sample["image"][rgb_indices, ...].permute(1, 2, 0).numpy()
    mask = sample["mask"].numpy()

    image = clip_image_percentile(image)

    if "prediction" in sample:
        prediction = sample["prediction"]
        num_images += 1
    else:
        prediction = None

    fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")

    ax[0].axis("off")

    norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
    ax[1].axis("off")
    ax[1].title.set_text("Image")
    ax[1].imshow(image)

    ax[2].axis("off")
    ax[2].title.set_text("Ground Truth Mask")
    ax[2].imshow(mask, cmap="jet", norm=norm)

    ax[3].axis("off")
    ax[3].title.set_text("GT Mask on Image")
    ax[3].imshow(image)
    ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)

    if "prediction" in sample:
        ax[4].title.set_text("Predicted Mask")
        ax[4].imshow(prediction, cmap="jet", norm=norm)

    cmap = plt.get_cmap("jet")
    legend_data = [[i, cmap(norm(i)), str(i)] for i in range(self.num_classes)]
    handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
    labels = [n for k, c, n in legend_data]
    ax[0].legend(handles, labels, loc="center")
    if suptitle is not None:
        plt.suptitle(suptitle)

    return fig
FireScarsSegmentationMask

Bases: RasterDataset

RasterDataset implementation for fire scars segmentation mask. Can be easily merged with input images using the & operator.

Source code in terratorch/datasets/fire_scars.py
class FireScarsSegmentationMask(RasterDataset):
    """RasterDataset implementation for fire scars segmentation mask.
    Can be easily merged with input images using the & operator.
    """

    filename_glob = "subsetted*.mask.tif"
    filename_regex = r"subsetted_512x512_HLS\..30\..{6}\.(?P<date>[0-9]*)\.v1\.4\.mask\.tif"
    date_format = "%Y%j"
    is_image = False
    separate_files = False

terratorch.datasets.landslide4sense

Landslide4SenseNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for Landslide4Sense.

Source code in terratorch/datasets/landslide4sense.py
class Landslide4SenseNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [Landslide4Sense](https://huggingface.co/datasets/ibm-nasa-geospatial/Landslide4sense)."""
    all_band_names = (
        "COASTAL AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "WATER_VAPOR",
        "CIRRUS",
        "SWIR_1",
        "SWIR_2",
        "SLOPE",
        "DEM",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")
    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "validation", "test": "test"}


    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
    ) -> None:
        """Initialize the Landslide4Sense dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.data_directory = Path(data_root)

        images_dir = self.data_directory / "images" / split_name
        annotations_dir = self.data_directory / "annotations" / split_name

        self.image_files = sorted(images_dir.glob("image_*.h5"))
        self.mask_files = sorted(annotations_dir.glob("mask_*.h5"))

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        image_file = self.image_files[index]
        mask_file = self.mask_files[index]

        with h5py.File(image_file, "r") as h5file:
            image = np.array(h5file["img"])[..., self.band_indices]

        with h5py.File(mask_file, "r") as h5file:
            mask = np.array(h5file["mask"])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        mask = sample["mask"].numpy()
        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]

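        # Per-channel min-max scaling to [0, 1] for display
        rgb_min = rgb_image.min(axis=(0, 1))
        rgb_max = rgb_image.max(axis=(0, 1))
        rgb_image = (rgb_image - rgb_min) / (rgb_max - rgb_min + 1e-8)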
        rgb_image = np.clip(rgb_image, 0, 1)

        num_classes = len(np.unique(mask))
        cmap = colormaps["jet"]
        norm = Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"]
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if sample.get("class_names"):
            class_names = sample["class_names"]
            legend_handles = [
                mpatches.Patch(color=cmap(norm(i)), label=class_names[i]) for i in range(num_classes)
            ]
            ax[0].legend(handles=legend_handles, bbox_to_anchor=(1.05, 1), loc="upper left")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
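For display, each RGB channel is usually rescaled to [0, 1] by subtracting its minimum and dividing by its range. A NumPy sketch of per-channel min-max scaling; `minmax_per_channel` is a hypothetical helper, and its `eps` term, an assumption of this sketch, avoids division by zero on constant channels.

```python
import numpy as np


def minmax_per_channel(rgb: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Scale each channel of an (H, W, 3) array to [0, 1] independently."""
    rgb = rgb.astype(np.float32)
    lo = rgb.min(axis=(0, 1), keepdims=True)  # per-channel minimum, shape (1, 1, 3)
    hi = rgb.max(axis=(0, 1), keepdims=True)  # per-channel maximum, shape (1, 1, 3)
    return np.clip((rgb - lo) / (hi - lo + eps), 0.0, 1.0)
```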
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None)

Initialize the Landslide4Sense dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to be used. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

Source code in terratorch/datasets/landslide4sense.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
) -> None:
    """Initialize the Landslide4Sense dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(b) for b in bands]

    self.data_directory = Path(data_root)

    images_dir = self.data_directory / "images" / split_name
    annotations_dir = self.data_directory / "annotations" / split_name

    self.image_files = sorted(images_dir.glob("image_*.h5"))
    self.mask_files = sorted(annotations_dir.glob("mask_*.h5"))

    self.transform = transform if transform else default_transform
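`band_indices` is simply the position of each requested band in `all_band_names`, and it is used to slice the channels-last HDF5 array (`image[..., band_indices]`). A minimal sketch; `band_indices` below is a hypothetical stand-in for TerraTorch's `validate_bands` plus the list comprehension, and its error message is an assumption.

```python
# Band order copied from Landslide4SenseNonGeo.all_band_names.
ALL_BAND_NAMES = (
    "COASTAL AEROSOL", "BLUE", "GREEN", "RED", "RED_EDGE_1", "RED_EDGE_2", "RED_EDGE_3",
    "NIR_BROAD", "WATER_VAPOR", "CIRRUS", "SWIR_1", "SWIR_2", "SLOPE", "DEM",
)


def band_indices(bands, all_band_names=ALL_BAND_NAMES) -> list[int]:
    """Validate band names, then return their channel positions in the file."""
    unknown = [band for band in bands if band not in all_band_names]
    if unknown:
        raise ValueError(f"Unknown bands: {unknown}")
    return [all_band_names.index(band) for band in bands]
```

Band names are dataset specific: `NIR_NARROW`, for example, exists in M-EuroSAT but not in Landslide4Sense, so it is rejected here.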

terratorch.datasets.m_eurosat

MEuroSATNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for M-EuroSAT.

Source code in terratorch/datasets/m_eurosat.py
class MEuroSATNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-EuroSAT](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "COASTAL_AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "WATER_VAPOR",
        "CIRRUS",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-eurosat"
    partition_file_template = "{partition}_partition.json"
    label_map_file = "label_map.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        label_map_path = self.data_directory / self.label_map_file
        with open(label_map_path) as file:
            self.label_map = json.load(file)

        self.id_to_class = {img_id: cls for cls, ids in self.label_map.items() for img_id in ids}

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)

        label_class = self.id_to_class[image_id]
        label_index = list(self.label_map.keys()).index(label_class)

        output = {"image": image.astype(np.float32)}

        if self.transform:
            output = self.transform(**output)

        output["label"] = label_index

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        label_index = sample["label"]

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        class_names = list(self.label_map.keys())
        class_name = class_names[label_index]

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title(f"Class: {class_name}")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
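`__getitem__` above turns the JSON label map (class name to list of image IDs) into an integer target in two steps. It inverts the map into `id_to_class`, then uses the class's position among the map's keys as the index. The minimal, stdlib-only sketch below reproduces that lookup. The toy label map and the helper name `build_label_lookup` are hypothetical. It relies on `json.loads` returning dicts in file order, which Python guarantees from 3.7 onward.

```python
import json


def build_label_lookup(label_map: dict[str, list[str]]) -> dict[str, int]:
    """Map each image ID to an integer class index (hypothetical helper).

    Mirrors the dataset logic: invert class -> IDs into ID -> class, then use
    the class's position in the label map's key order as the index.
    """
    id_to_class = {img_id: cls for cls, ids in label_map.items() for img_id in ids}
    class_names = list(label_map.keys())
    return {img_id: class_names.index(cls) for img_id, cls in id_to_class.items()}


# Toy label map in the same shape as the file the dataset reads (values made up).
label_map = json.loads('{"Forest": ["id_001", "id_002"], "River": ["id_003"], "SeaLake": ["id_004"]}')
lookup = build_label_lookup(label_map)
print(lookup["id_003"])  # 1
```

Precomputing the table this way avoids the linear `list(...).index(...)` search that the dataset performs on every `__getitem__` call. The resulting indices are the same.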
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default')

Initialize the dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to be used. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

  • partition (str, default: 'default' ) –

    Partition name for the dataset splits. Defaults to 'default'.
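The `split` and `partition` arguments are resolved against a JSON partition file in three steps:

- The public split name is translated through `splits`, so `'val'` becomes `'valid'`.
- The file name is built from `partition_file_template`.
- Each listed ID becomes `<id>.hdf5` inside the dataset directory.

The sketch below reproduces that resolution using only the standard library. The helper `resolve_split_files` and the temporary `toy-dataset` directory are hypothetical stand-ins for a real GEO-Bench download.

```python
import json
import tempfile
from pathlib import Path

SPLITS = {"train": "train", "val": "valid", "test": "test"}  # same alias table as the class
PARTITION_FILE_TEMPLATE = "{partition}_partition.json"


def resolve_split_files(data_directory: Path, split: str, partition: str = "default") -> list[Path]:
    """Return the HDF5 paths listed for `split` in the chosen partition file (hypothetical helper)."""
    if split not in SPLITS:
        raise ValueError(f"Incorrect split '{split}', please choose one of {list(SPLITS)}.")
    split_name = SPLITS[split]
    partition_file = data_directory / PARTITION_FILE_TEMPLATE.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)
    if split_name not in partitions:
        raise ValueError(f"Split '{split_name}' not found.")
    return [data_directory / f"{name}.hdf5" for name in partitions[split_name]]


# Build a throwaway directory with a toy partition file, then resolve the validation split.
with tempfile.TemporaryDirectory() as tmp:
    data_directory = Path(tmp) / "toy-dataset"
    data_directory.mkdir()
    toy_partitions = {"train": ["a", "b"], "valid": ["c"], "test": ["d"]}
    (data_directory / "default_partition.json").write_text(json.dumps(toy_partitions))
    val_files = resolve_split_files(data_directory, "val")
    print([p.name for p in val_files])  # ['c.hdf5']
```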

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by :meth:`__getitem__`.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
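`plot` looks up the positions of `RED`, `GREEN` and `BLUE` inside the selected `bands`, not inside `all_band_names`. If any of the three is missing, it refuses to draw. It then moves a CHW tensor to HWC layout before slicing out the RGB channels. The sketch below shows the index computation and the channel reordering on a nested-list image, so it runs without NumPy or PyTorch. The helper names `rgb_channel_indices` and `to_rgb_hwc` are hypothetical. The contrast stretch done by `clip_image` is not reproduced.

```python
RGB_BANDS = ("RED", "GREEN", "BLUE")


def rgb_channel_indices(bands):
    """Positions of R, G, B within the selected bands, as computed in `plot`."""
    indices = [bands.index(band) for band in RGB_BANDS if band in bands]
    if len(indices) != 3:
        raise ValueError("Dataset doesn't contain some of the RGB bands")
    return indices


def to_rgb_hwc(image_chw, indices):
    """Pick the RGB channels from a C x H x W nested list and return H x W x 3."""
    height, width = len(image_chw[0]), len(image_chw[0][0])
    return [[[image_chw[c][y][x] for c in indices] for x in range(width)] for y in range(height)]


bands = ("BLUE", "GREEN", "RED", "NIR_BROAD")
indices = rgb_channel_indices(bands)
print(indices)  # [2, 1, 0]
image = [[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]]  # 4 channels, 1 x 2 pixels
print(to_rgb_hwc(image, indices))  # [[[5, 3, 1], [6, 4, 2]]]
```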


terratorch.datasets.m_bigearthnet

MBigEarthNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for M-BigEarthNet.

Source code in terratorch/datasets/m_bigearthnet.py
class MBigEarthNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-BigEarthNet](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "COASTAL_AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "WATER_VAPOR",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-bigearthnet"
    label_map_file = "label_stats.json"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        label_map_path = self.data_directory / self.label_map_file
        with open(label_map_path) as file:
            self.label_map = json.load(file)

        self.num_classes = len(next(iter(self.label_map.values())))

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found in partition file."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)

        labels_vector = self.label_map[image_id]
        labels_tensor = torch.tensor(labels_vector, dtype=torch.float)

        output = {"image": image}

        if self.transform:
            output = self.transform(**output)

        output["label"] = labels_tensor
        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        label = sample["label"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()  # Convert to (H, W, C)

        rgb_image = image[:, :, rgb_indices]

        rgb_image = clip_image(rgb_image)

        active_labels = [i for i, lbl in enumerate(label) if lbl == 1]

        fig, ax = plt.subplots(figsize=(6, 6))

        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title(f"Active Labels: {active_labels}")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
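Unlike the single-label datasets, M-BigEarthNet's label file maps each image ID directly to a multi-hot vector. As a result, `num_classes` is read off the length of the first vector, and `plot` lists the positions that hold a 1. The stdlib-only sketch below shows those two steps. The three-image label map is invented for illustration. The dataset itself wraps each vector in a `torch.float` tensor, which is the target dtype that `torch.nn.BCEWithLogitsLoss` expects.

```python
import json

# Hypothetical label file contents: image ID -> multi-hot vector (one slot per class).
label_map = json.loads('{"img_a": [1, 0, 0, 1], "img_b": [0, 1, 0, 0], "img_c": [0, 0, 0, 0]}')

# Same expression the dataset uses; it assumes every vector has the same length.
num_classes = len(next(iter(label_map.values())))


def active_labels(vector):
    """Indices of the classes present in a multi-hot vector, as shown in `plot`."""
    return [i for i, value in enumerate(vector) if value == 1]


print(num_classes)                        # 4
print(active_labels(label_map["img_a"]))  # [0, 3]
print(active_labels(label_map["img_c"]))  # []
```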
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default')

Initialize the dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to be used. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

  • partition (str, default: 'default' ) –

    Partition name for the dataset splits. Defaults to 'default'.
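`__getitem__` does not look bands up by name in the HDF5 file. Instead, it sorts the file's keys, drops `"label"`, and indexes the result with `band_indices`, which holds the positions of the requested names in `all_band_names`. This picks the right arrays only if the sorted keys follow the same order as `all_band_names`. The sketch below makes that assumption explicit with made-up, zero-padded key names such as `"00_COASTAL_AEROSOL"`; the real key names in the GEO-Bench files are not shown on this page. The name check is a hypothetical stand-in for TerraTorch's `validate_bands`, whose exact behavior is not documented here.

```python
ALL_BAND_NAMES = ("COASTAL_AEROSOL", "BLUE", "GREEN", "RED", "RED_EDGE_1")  # first five bands only


def select_band_keys(h5_keys, all_band_names, bands):
    """Reproduce the dataset's key selection: sort, drop "label", index by band position."""
    unknown = [b for b in bands if b not in all_band_names]
    if unknown:
        # Hypothetical check standing in for validate_bands.
        raise ValueError(f"Unknown bands: {unknown}")
    band_indices = [all_band_names.index(b) for b in bands]
    keys = [key for key in sorted(h5_keys) if key != "label"]
    return [keys[i] for i in band_indices]


# Made-up key names whose alphabetical order matches ALL_BAND_NAMES.
h5_keys = ["label", "03_RED", "00_COASTAL_AEROSOL", "02_GREEN", "01_BLUE", "04_RED_EDGE_1"]
print(select_band_keys(h5_keys, ALL_BAND_NAMES, ["RED", "GREEN", "BLUE"]))
# ['03_RED', '02_GREEN', '01_BLUE']
```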

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by :meth:`__getitem__`.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.


terratorch.datasets.m_brick_kiln

MBrickKilnNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for M-BrickKiln.

Source code in terratorch/datasets/m_brick_kiln.py
class MBrickKilnNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-BrickKiln](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "COASTAL_AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "WATER_VAPOR",
        "CIRRUS",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-brick-kiln"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found in partition file."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))
            class_index = attr_dict["label"]

        output = {"image": image.astype(np.float32)}

        if self.transform:
            output = self.transform(**output)

        output["label"] = class_index

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        label = sample["label"]

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()  # Convert to (H, W, C)

        rgb_image = image[:, :, rgb_indices]

        rgb_image = clip_image(rgb_image)

        fig, ax = plt.subplots(figsize=(6, 6))

        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title(f"Class: {label}")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
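M-BrickKiln (and M-ForestNet below) reads the class index from an HDF5 attribute named `"pickle"`. Judging from the decoding code, that attribute holds the text form of a pickled `bytes` object. `ast.literal_eval` turns the text back into `bytes`, and `pickle.loads` restores the metadata dictionary. The round trip below imitates that encoding with an in-memory string. How GEO-Bench actually writes the attribute is assumed here, not shown on this page. Because `pickle.loads` can execute arbitrary code, only read files from a source you trust.

```python
import ast
import pickle

# Assumed encoding: the attribute stores repr() of the pickled metadata bytes.
metadata = {"label": 1}
attr_value = repr(pickle.dumps(metadata))  # a str that starts with "b'"

# Decoding as done in __getitem__: text -> bytes literal -> original dict.
raw_bytes = ast.literal_eval(attr_value)
decoded = pickle.loads(raw_bytes)  # only safe for trusted files
print(decoded["label"])  # 1
```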
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default')

Initialize the dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to be used. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

  • partition (str, default: 'default' ) –

    Partition name for the dataset splits. Defaults to 'default'.

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by :meth:`__getitem__`.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.


terratorch.datasets.m_forestnet

MForestNetNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for M-ForestNet.

Source code in terratorch/datasets/m_forestnet.py
class MForestNetNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-ForestNet](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "BLUE",
        "GREEN",
        "RED",
        "NIR",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-forestnet"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
        use_metadata: bool = False,
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
            use_metadata (bool): Whether to return metadata info (time and location).
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))  # noqa: S301
            class_index = attr_dict["label"]

        output = {"image": image.astype(np.float32)}

        if self.transform:
            output = self.transform(**output)

        output["label"] = class_index

        if self.use_metadata:
            temporal_coords = self._get_date(image_id)
            location_coords = self._get_coords(image_id)

            output["temporal_coords"] = temporal_coords
            output["location_coords"] = location_coords

        return output

    def _get_coords(self, image_id: str) -> torch.Tensor:
        """Extract spatial coordinates from the image ID.

        Args:
            image_id (str): The ID of the image.

        Returns:
            torch.Tensor: Tensor containing latitude and longitude.
        """
        lat_str, lon_str, _ = image_id.split("_", 2)
        latitude = float(lat_str)
        longitude = float(lon_str)
        return torch.tensor([latitude, longitude], dtype=torch.float32)

    def _get_date(self, image_id: str) -> torch.Tensor:
        _, _, date_str = image_id.split("_", 2)
        date = pd.to_datetime(date_str, format="%Y_%m_%d")

        return torch.tensor([[date.year, date.dayofyear - 1]], dtype=torch.float32)

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        label = sample["label"]

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()  # (H, W, C)

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title(f"Class: {label}")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default', use_metadata=False)

Initialize the dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to be used. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

  • partition (str, default: 'default' ) –

    Partition name for the dataset splits. Defaults to 'default'.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info (time and location).
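With `use_metadata=True`, the dataset derives both extra tensors from the file name alone. The ID is split into latitude, longitude and a `YYYY_MM_DD` date, and the date becomes `[year, day_of_year - 1]`, a zero-based day count. The sketch below repeats that parsing with `datetime` instead of pandas, and it returns plain lists instead of tensors. The example ID is invented; it only assumes the `<lat>_<lon>_<YYYY>_<MM>_<DD>` layout that the code splits on.

```python
from datetime import datetime


def parse_forestnet_id(image_id: str) -> tuple[list[float], list[list[int]]]:
    """Split an ID into ([lat, lon], [[year, zero_based_day_of_year]])."""
    lat_str, lon_str, date_str = image_id.split("_", 2)
    date = datetime.strptime(date_str, "%Y_%m_%d")
    day_of_year = date.timetuple().tm_yday  # 1-based, like pandas' dayofyear
    return [float(lat_str), float(lon_str)], [[date.year, day_of_year - 1]]


coords, temporal = parse_forestnet_id("-2.5_110.25_2015_03_01")
print(coords)    # [-2.5, 110.25]
print(temporal)  # [[2015, 59]]
```

The nested `[[year, day]]` shape matches the `(1, 2)` tensor that `_get_date` builds.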

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by :meth:`__getitem__`.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.


terratorch.datasets.m_so2sat

MSo2SatNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for M-So2Sat.
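M-So2Sat interleaves Sentinel-1 SAR channels (`VH_*`, `VV_*`) with Sentinel-2 optical channels in `all_band_names`. Because of this, selecting a subset by name is easier than selecting it by position. The short sketch below splits the tuple into SAR and optical groups using a name-prefix rule. That rule is an assumption of this sketch, not something the class defines. The printed indices are the positions that `band_indices` would hold for each selection.

```python
ALL_BAND_NAMES = (
    "VH_REAL", "BLUE", "VH_IMAGINARY", "GREEN", "VV_REAL", "RED", "VV_IMAGINARY",
    "VH_LEE_FILTERED", "RED_EDGE_1", "VV_LEE_FILTERED", "RED_EDGE_2", "VH_LEE_FILTERED_REAL",
    "RED_EDGE_3", "NIR_BROAD", "VV_LEE_FILTERED_IMAGINARY", "NIR_NARROW", "SWIR_1", "SWIR_2",
)  # copied from MSo2SatNonGeo.all_band_names

# Assumed rule: SAR channel names start with a polarisation prefix.
sar_bands = [b for b in ALL_BAND_NAMES if b.startswith(("VH_", "VV_"))]
optical_bands = [b for b in ALL_BAND_NAMES if b not in sar_bands]

# Positions the dataset would store in band_indices for each selection.
sar_indices = [ALL_BAND_NAMES.index(b) for b in sar_bands]
optical_indices = [ALL_BAND_NAMES.index(b) for b in optical_bands]
print(len(sar_bands), len(optical_bands))  # 8 10
print(sar_indices)  # [0, 2, 4, 6, 7, 9, 11, 14]
```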

Source code in terratorch/datasets/m_so2sat.py
class MSo2SatNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-So2Sat](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "VH_REAL",
        "BLUE",
        "VH_IMAGINARY",
        "GREEN",
        "VV_REAL",
        "RED",
        "VV_IMAGINARY",
        "VH_LEE_FILTERED",
        "RED_EDGE_1",
        "VV_LEE_FILTERED",
        "RED_EDGE_2",
        "VH_LEE_FILTERED_REAL",
        "RED_EDGE_3",
        "NIR_BROAD",
        "VV_LEE_FILTERED_IMAGINARY",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-so2sat"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))
            class_index = attr_dict["label"]

        output = {"image": image.astype(np.float32)}

        if self.transform:
            output = self.transform(**output)

        output["label"] = class_index

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        label_index = sample["label"]

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        class_name = str(label_index)

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title(f"Class: {class_name}")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
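`__getitem__` never looks bands up by name. Instead it sorts the HDF5 keys, drops `label`, and indexes the result with `band_indices`. The sketch below reproduces that logic, with a plain dict standing in for the `h5py.File`. The key names are invented for illustration. The check at the top is a hypothetical stand-in for `validate_bands`, whose exact checks are not shown on this page.

```python
from collections.abc import Mapping, Sequence


def select_band_keys(file_keys: Mapping, all_band_names: Sequence[str], bands: Sequence[str]) -> list[str]:
    """Return the stored keys for `bands`, following the key logic of __getitem__."""
    # Hypothetical stand-in for validate_bands(): reject names not in all_band_names.
    unknown = [band for band in bands if band not in all_band_names]
    if unknown:
        raise ValueError(f"Unknown bands: {unknown}")
    band_indices = [all_band_names.index(band) for band in bands]
    # Sort the keys and drop the label, as __getitem__ does. This only works if
    # the sorted key order matches the order of all_band_names.
    data_keys = [key for key in sorted(file_keys) if key != "label"]
    return [data_keys[i] for i in band_indices]


# Hypothetical file layout: numeric prefixes make the sorted order match all_band_names.
fake_file = {"00_BLUE": 0, "01_GREEN": 0, "02_RED": 0, "label": 3}
print(select_band_keys(fake_file, ("BLUE", "GREEN", "RED"), ["RED", "BLUE"]))  # ['02_RED', '00_BLUE']
```

This reliance on sorted order probably explains why `all_band_names` for M-So2Sat interleaves SAR and optical names: the tuple has to follow the order in which the stored keys sort.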

__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default')

Initialize the dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to be used. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

  • partition (str, default: 'default' ) –

    Partition name for the dataset splits. Defaults to 'default'.
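The mapping from `split` and `partition` to file paths can be reproduced with the standard library alone. This sketch follows the same steps as `__init__`: it maps the public split name (for example "val") to the stored name ("valid"), reads `{partition}_partition.json`, and appends `.hdf5` to each listed sample. The partition file contents are invented for illustration. Band validation and transforms are left out.

```python
import json
import tempfile
from pathlib import Path

SPLITS = {"train": "train", "val": "valid", "test": "test"}
PARTITION_FILE_TEMPLATE = "{partition}_partition.json"


def resolve_image_files(data_directory: Path, split: str = "train", partition: str = "default") -> list[Path]:
    """List the HDF5 files for a split, following the same steps as __init__."""
    if split not in SPLITS:
        raise ValueError(f"Incorrect split '{split}', please choose one of {list(SPLITS)}.")
    split_name = SPLITS[split]  # e.g. "val" is stored as "valid"
    partition_file = data_directory / PARTITION_FILE_TEMPLATE.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)
    if split_name not in partitions:
        raise ValueError(f"Split '{split_name}' not found.")
    return [data_directory / f"{name}.hdf5" for name in partitions[split_name]]


with tempfile.TemporaryDirectory() as tmp:
    data_directory = Path(tmp) / "m-so2sat"
    data_directory.mkdir()
    # Hypothetical partition contents; real partition files list many more samples.
    (data_directory / "default_partition.json").write_text(
        json.dumps({"train": ["a"], "valid": ["b", "c"], "test": []})
    )
    print([p.name for p in resolve_image_files(data_directory, "val")])  # ['b.hdf5', 'c.hdf5']
```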

Source code in terratorch/datasets/m_so2sat.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(b) for b in bands]

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.

Source code in terratorch/datasets/m_so2sat.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    label_index = sample["label"]

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    class_name = str(label_index)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(rgb_image)
    ax.axis("off")
    ax.set_title(f"Class: {class_name}")

    if suptitle:
        plt.suptitle(suptitle)

    return fig

terratorch.datasets.m_pv4ger

MPv4gerNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for M-PV4GER.

Source code in terratorch/datasets/m_pv4ger.py
class MPv4gerNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-PV4GER](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = ("BLUE", "GREEN", "RED")

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-pv4ger"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
        use_metadata: bool = False,
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
            use_metadata (bool): Whether to return metadata info (location coordinates).
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            attr_dict = pickle.loads(ast.literal_eval(h5file.attrs["pickle"]))  # noqa: S301
            class_index = attr_dict["label"]

        output = {"image": image.astype(np.float32)}

        if self.transform:
            output = self.transform(**output)

        output["label"] = class_index

        if self.use_metadata:
            output["location_coords"] = self._get_coords(image_id)

        return output

    def _get_coords(self, image_id: str) -> torch.Tensor:
        """Extract spatial coordinates from the image ID."""
        lat_str, lon_str = image_id.split(",")
        latitude = float(lat_str)
        longitude = float(lon_str)
        return torch.tensor([latitude, longitude], dtype=torch.float32)

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        label = sample["label"]

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(rgb_image)
        ax.axis("off")
        ax.set_title(f"Class: {label}")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
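In M-PV4GER the image ID, which is the file stem, is itself the coordinate pair "latitude,longitude". `_get_coords` therefore only splits on the comma. The sketch below returns a plain list instead of a tensor. The sample path is hypothetical. `Path.stem` removes only the final `.hdf5` suffix, so the decimal points in the coordinates survive.

```python
from pathlib import Path


def get_coords(file_path: str) -> list[float]:
    """Parse [latitude, longitude] from an M-PV4GER file path, as _get_coords does (list instead of tensor)."""
    image_id = Path(file_path).stem  # "48.137,11.575.hdf5" -> "48.137,11.575"
    lat_str, lon_str = image_id.split(",")
    return [float(lat_str), float(lon_str)]


print(get_coords("m-pv4ger/48.137,11.575.hdf5"))  # [48.137, 11.575]
```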

__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default', use_metadata=False)

Initialize the dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to be used. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

  • partition (str, default: 'default' ) –

    Partition name for the dataset splits. Defaults to 'default'.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info (location coordinates).

Source code in terratorch/datasets/m_pv4ger.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
    use_metadata: bool = False,
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
        use_metadata (bool): Whether to return metadata info (location coordinates).
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.use_metadata = use_metadata

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.

Source code in terratorch/datasets/m_pv4ger.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    label = sample["label"]

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(rgb_image)
    ax.axis("off")
    ax.set_title(f"Class: {label}")

    if suptitle:
        plt.suptitle(suptitle)

    return fig

terratorch.datasets.m_cashew_plantation

MBeninSmallHolderCashewsNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for M-BeninSmallHolderCashews.

Source code in terratorch/datasets/m_cashew_plantation.py
class MBeninSmallHolderCashewsNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-BeninSmallHolderCashews](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "COASTAL_AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "WATER_VAPOR",
        "SWIR_1",
        "SWIR_2",
        "CLOUD_PROBABILITY",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-cashew-plant"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
        use_metadata: bool = False,
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
            use_metadata (bool): Whether to return metadata info (time).
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found in partition file."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

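    def _get_date(self, keys) -> torch.Tensor:
        """Return [[year, day_of_year - 1]] for the first key containing a YYYY-MM-DD date, or zeros if none does."""
        date_pattern = re.compile(r"\d{4}-\d{2}-\d{2}")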

        date_str = None
        for key in keys:
            match = date_pattern.search(key)
            if match:
                date_str = match.group()
                break

        date = torch.zeros((1, 2), dtype=torch.float32)
        if date_str:
            date = pd.to_datetime(date_str, format="%Y-%m-%d")
            date = torch.tensor([[date.year, date.dayofyear - 1]], dtype=torch.float32)

        return date

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            temporal_coords = self._get_date(h5file)
            mask = np.array(h5file["label"])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()
        if self.use_metadata:
            output["temporal_coords"] = temporal_coords

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        mask = sample["mask"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()  # (H, W, C)

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        num_classes = len(np.unique(mask))
        cmap = plt.get_cmap("jet")
        norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
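`_get_date` scans the HDF5 keys for the first `YYYY-MM-DD` substring and encodes it as `[[year, day_of_year - 1]]`. The `- 1` turns pandas' 1-based `dayofyear` into a 0-based index. If no key contains a date, the result is all zeros. The sketch below uses `datetime` instead of pandas and returns nested lists instead of a tensor. The key names are hypothetical.

```python
import re
from datetime import date

DATE_PATTERN = re.compile(r"\d{4}-\d{2}-\d{2}")


def get_date(keys) -> list[list[float]]:
    """Return [[year, day_of_year - 1]] for the first key containing a YYYY-MM-DD date, else zeros."""
    for key in keys:
        match = DATE_PATTERN.search(key)
        if match:
            d = date.fromisoformat(match.group())
            # tm_yday is 1-based (Jan 1 -> 1), like pandas' dayofyear.
            return [[float(d.year), float(d.timetuple().tm_yday - 1)]]
    return [[0.0, 0.0]]


print(get_date(["00_BLUE_2017-03-01", "label"]))  # [[2017.0, 59.0]]
print(get_date(["label"]))  # [[0.0, 0.0]]
```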

__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default', use_metadata=False)

Initialize the dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to be used. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

  • partition (str, default: 'default' ) –

    Partition name for the dataset splits. Defaults to 'default'.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info (time).

Source code in terratorch/datasets/m_cashew_plantation.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
    use_metadata: bool = False,
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
        use_metadata (bool): Whether to return metadata info (time).
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.use_metadata = use_metadata

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found in partition file."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.

Source code in terratorch/datasets/m_cashew_plantation.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    mask = sample["mask"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()  # (H, W, C)

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    num_classes = len(np.unique(mask))
    cmap = plt.get_cmap("jet")
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    num_images = 4 if "prediction" in sample else 3
    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

    ax[0].imshow(rgb_image)
    ax[0].set_title("Image")
    ax[0].axis("off")

    ax[1].imshow(mask, cmap=cmap, norm=norm)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(rgb_image)
    ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
    ax[2].set_title("GT Mask on Image")
    ax[2].axis("off")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[3].imshow(prediction, cmap=cmap, norm=norm)
        ax[3].set_title("Predicted Mask")
        ax[3].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    return fig

terratorch.datasets.m_nz_cattle

MNzCattleNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for M-NZ-Cattle.

Source code in terratorch/datasets/m_nz_cattle.py
class MNzCattleNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-NZ-Cattle](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = ("BLUE", "GREEN", "RED")

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-nz-cattle"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
        use_metadata: bool = False,
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
            use_metadata (bool): Whether to return metadata info (time and location).
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        file_name = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())

            data_keys = [key for key in keys if "label" not in key]
            label_keys = [key for key in keys if "label" in key]

            temporal_coords = self._get_date(data_keys[0])

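            # Load only the requested bands. Like the other GEO-Bench datasets, this
            # assumes the sorted data keys follow the order of all_band_names.
            bands = [np.array(h5file[data_keys[i]]) for i in self.band_indices]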
            image = np.stack(bands, axis=-1)

            mask = np.array(h5file[label_keys[0]])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        if self.use_metadata:
            location_coords = self._get_coords(file_name)
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def _get_coords(self, file_name: str) -> torch.Tensor:
        """Extract spatial coordinates from the file name."""
        match = re.search(r"_(\-?\d+\.\d+),(\-?\d+\.\d+)", file_name)
        if match is None:
            msg = f"Could not parse coordinates from file name '{file_name}'."
            raise ValueError(msg)
        longitude, latitude = map(float, match.groups())

        return torch.tensor([latitude, longitude], dtype=torch.float32)

    def _get_date(self, band_name: str) -> torch.Tensor:
        """Return [[year, day_of_year - 1]] parsed from the YYYY-MM-DD suffix of a band name."""
        date_str = band_name.split("_")[-1]
        date = pd.to_datetime(date_str, format="%Y-%m-%d")

        return torch.tensor([[date.year, date.dayofyear - 1]], dtype=torch.float32)

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        mask = sample["mask"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        num_classes = len(np.unique(mask))
        cmap = plt.get_cmap("jet")
        norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig

__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default', use_metadata=False)

Initialize the dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to be used. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

  • partition (str, default: 'default' ) –

    Partition name for the dataset splits. Defaults to 'default'.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info (time and location).
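
With use_metadata=True, __getitem__ adds location_coords, parsed from the file name, and temporal_coords, parsed from the date suffix of the first (sorted) band key. The sketch below mirrors that parsing with the standard library only and returns plain tuples instead of tensors. It assumes, as the source code does, file stems containing "_<longitude>,<latitude>" and band keys ending in "_YYYY-MM-DD"; the example names are hypothetical, and the explicit error for a missing match is part of this sketch.

```python
import re
from datetime import date

# Same pattern as in _get_coords: "_<longitude>,<latitude>" inside the file stem.
COORD_PATTERN = re.compile(r"_(-?\d+\.\d+),(-?\d+\.\d+)")


def parse_location(file_stem: str) -> tuple[float, float]:
    """Return (latitude, longitude), the order used for location_coords."""
    match = COORD_PATTERN.search(file_stem)
    if match is None:
        raise ValueError(f"No coordinates found in file name '{file_stem}'.")
    longitude, latitude = map(float, match.groups())
    return latitude, longitude


def parse_date(band_key: str) -> tuple[int, int]:
    """Return (year, zero-based day of year) from a key ending in '_YYYY-MM-DD'."""
    day = date.fromisoformat(band_key.split("_")[-1])
    return day.year, day.timetuple().tm_yday - 1


# Hypothetical names, for illustration only.
print(parse_location("cattle_172.5,-41.3"))  # (-41.3, 172.5)
print(parse_date("BLUE_2019-03-15"))         # (2019, 73)
```

The day of year is zero-based, so 1 January maps to 0, matching date.dayofyear - 1 in _get_date.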

Source code in terratorch/datasets/m_nz_cattle.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
    use_metadata: bool = False,
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
        use_metadata (bool): Whether to return metadata info (time and location).
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(b) for b in bands]

    self.use_metadata = use_metadata

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
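
plot assumes sample["image"] is channel-first (C, H, W), moves the channels last, and selects red, green, and blue by looking their names up in self.bands. A NumPy-only sketch of that step, using a hypothetical four-band subset (the band names and the select_rgb helper are illustrative, not part of TerraTorch):

```python
import numpy as np

RGB_BANDS = ("RED", "GREEN", "BLUE")


def select_rgb(image_chw: np.ndarray, bands: tuple[str, ...]) -> np.ndarray:
    """Return an (H, W, 3) RGB array from a channel-first image, as plot() does."""
    rgb_indices = [bands.index(band) for band in RGB_BANDS if band in bands]
    if len(rgb_indices) != 3:
        raise ValueError("Dataset doesn't contain some of the RGB bands")
    image_hwc = np.transpose(image_chw, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    return image_hwc[:, :, rgb_indices]


# Hypothetical 4-band sample: channel i is filled with the value i.
bands = ("BLUE", "GREEN", "RED", "NIR")
image = np.stack([np.full((2, 2), i, dtype=np.float32) for i in range(4)])
rgb = select_rgb(image, bands)
print(rgb.shape)  # (2, 2, 3)
print(rgb[0, 0])  # [2. 1. 0.]
```

Because the lookup goes through self.bands, the channel order of the image must match the order of the bands passed to the constructor.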

Source code in terratorch/datasets/m_nz_cattle.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    mask = sample["mask"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    num_classes = len(np.unique(mask))
    cmap = plt.get_cmap("jet")
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    num_images = 4 if "prediction" in sample else 3
    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

    ax[0].imshow(rgb_image)
    ax[0].set_title("Image")
    ax[0].axis("off")

    ax[1].imshow(mask, cmap=cmap, norm=norm)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(rgb_image)
    ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
    ax[2].set_title("GT Mask on Image")
    ax[2].axis("off")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[3].imshow(prediction, cmap=cmap, norm=norm)
        ax[3].set_title("Predicted Mask")
        ax[3].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    return fig

terratorch.datasets.m_chesapeake_landcover

MChesapeakeLandcoverNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for M-ChesapeakeLandcover.

Source code in terratorch/datasets/m_chesapeake_landcover.py
class MChesapeakeLandcoverNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-ChesapeakeLandcover](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = ("BLUE", "GREEN", "NIR", "RED")

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-chesapeake"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found in partition file."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            mask = np.array(h5file["label"])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        mask = sample["mask"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()  # (H, W, C)

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        num_classes = len(np.unique(mask))
        cmap = plt.get_cmap("jet")
        norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
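
__getitem__ reads every non-label dataset from the HDF5 file in sorted key order, keeps the ones selected by band_indices, and stacks them channel-last. The sketch below reproduces that selection with a plain dict standing in for the open h5py.File, so it runs without h5py or the GEO-Bench files. The key names are hypothetical and, as in the source, sorted key order is assumed to match all_band_names (which holds for the alphabetical band names of this class).

```python
import numpy as np

ALL_BAND_NAMES = ("BLUE", "GREEN", "NIR", "RED")


def load_sample(h5file: dict[str, np.ndarray], bands: tuple[str, ...]) -> dict[str, np.ndarray]:
    """Mimic the band selection and stacking done in __getitem__."""
    band_indices = np.array([ALL_BAND_NAMES.index(b) for b in bands])
    keys = sorted(h5file.keys())
    # Drop the label key, then keep only the requested bands in the requested order.
    keys = np.array([key for key in keys if key != "label"])[band_indices]
    image = np.stack([np.asarray(h5file[key]) for key in keys], axis=-1)  # (H, W, C)
    mask = np.asarray(h5file["label"])
    return {"image": image.astype(np.float32), "mask": mask}


# Stand-in for an open h5py.File: band i is a constant 3x3 array filled with i.
fake_file = {name: np.full((3, 3), i) for i, name in enumerate(ALL_BAND_NAMES)}
fake_file["label"] = np.zeros((3, 3), dtype=np.int64)

sample = load_sample(fake_file, bands=("RED", "GREEN", "BLUE"))
print(sample["image"].shape)  # (3, 3, 3)
print(sample["image"][0, 0])  # [3. 1. 0.]
```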

__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default')

Initialize the dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to be used. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

  • partition (str, default: 'default' ) –

    Partition name for the dataset splits. Defaults to 'default'.
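
The split argument is first mapped through the class's splits dict (so 'val' becomes 'valid'), then looked up in <data_root>/<data_dir>/<partition>_partition.json, whose entries are file stems that receive a .hdf5 suffix. A self-contained sketch of that resolution, using a temporary directory and a hypothetical partition file with made-up sample ids; only the directory name m-chesapeake is taken from the class above.

```python
import json
import tempfile
from pathlib import Path

SPLITS = {"train": "train", "val": "valid", "test": "test"}


def resolve_files(
    data_root: str, data_dir: str, split: str, partition: str = "default"
) -> list[Path]:
    """Return the HDF5 paths for one split, following the dataset's lookup rules."""
    if split not in SPLITS:
        raise ValueError(f"Incorrect split '{split}', please choose one of {list(SPLITS)}.")
    split_name = SPLITS[split]
    data_directory = Path(data_root) / data_dir
    with open(data_directory / f"{partition}_partition.json") as file:
        partitions = json.load(file)
    if split_name not in partitions:
        raise ValueError(f"Split '{split_name}' not found in partition file.")
    return [data_directory / f"{stem}.hdf5" for stem in partitions[split_name]]


with tempfile.TemporaryDirectory() as root:
    directory = Path(root) / "m-chesapeake"
    directory.mkdir()
    # Hypothetical partition file.
    (directory / "default_partition.json").write_text(
        json.dumps({"train": ["a", "b"], "valid": ["c"], "test": ["d"]})
    )
    print([p.name for p in resolve_files(root, "m-chesapeake", "val")])  # ['c.hdf5']
```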

Source code in terratorch/datasets/m_chesapeake_landcover.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found in partition file."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
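
plot sets the color range from the labels present in the current sample (num_classes = len(np.unique(mask))), so the same class can be drawn in different colors across samples, and a mask whose label values have gaps can contain values above vmax. If consistent colors are needed, one option is to fix the range from a known class count. The color_range helper below is a hypothetical sketch of that choice, not part of TerraTorch.

```python
from __future__ import annotations

import numpy as np


def color_range(mask: np.ndarray, num_classes: int | None = None) -> tuple[int, int]:
    """Return (vmin, vmax) suitable for plt.Normalize.

    Without num_classes, mimic plot(): count the labels present in this mask.
    With num_classes, use a fixed range so colors match across samples.
    """
    if num_classes is None:
        num_classes = len(np.unique(mask))
    return 0, num_classes - 1


mask = np.array([[0, 0], [5, 5]])        # only labels 0 and 5 appear
print(color_range(mask))                 # (0, 1): label 5 lies outside this range
print(color_range(mask, num_classes=7))  # (0, 6)
```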

Source code in terratorch/datasets/m_chesapeake_landcover.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    mask = sample["mask"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()  # (H, W, C)

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    num_classes = len(np.unique(mask))
    cmap = plt.get_cmap("jet")
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    num_images = 4 if "prediction" in sample else 3
    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

    ax[0].imshow(rgb_image)
    ax[0].set_title("Image")
    ax[0].axis("off")

    ax[1].imshow(mask, cmap=cmap, norm=norm)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(rgb_image)
    ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
    ax[2].set_title("GT Mask on Image")
    ax[2].axis("off")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[3].imshow(prediction, cmap=cmap, norm=norm)
        ax[3].set_title("Predicted Mask")
        ax[3].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    return fig

terratorch.datasets.m_pv4ger_seg

MPv4gerSegNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for M-PV4GER-SEG.

Source code in terratorch/datasets/m_pv4ger_seg.py
class MPv4gerSegNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-PV4GER-SEG](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = ("BLUE", "GREEN", "RED")

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-pv4ger-seg"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
        use_metadata: bool = False,
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
            use_metadata (bool): Whether to return metadata info (location coordinates).
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.use_metadata = use_metadata

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]
        image_id = file_path.stem

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            mask = np.array(h5file["label"])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        if self.use_metadata:
            output["location_coords"] = self._get_coords(image_id)

        return output

    def _get_coords(self, image_id: str) -> torch.Tensor:
        """Extract spatial coordinates from the image ID."""
        lat_str, lon_str = image_id.split(",")
        latitude = float(lat_str)
        longitude = float(lon_str)
        return torch.tensor([latitude, longitude], dtype=torch.float32)

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        mask = sample["mask"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        num_classes = len(np.unique(mask))
        cmap = plt.get_cmap("jet")
        norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig

__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default', use_metadata=False)

Initialize the dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to be used. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

  • partition (str, default: 'default' ) –

    Partition name for the dataset splits. Defaults to 'default'.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info (location coordinates).
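
When use_metadata=True, MPv4gerSegNonGeo takes location_coords from the file stem, which _get_coords splits on a single comma into latitude and longitude. A standard-library sketch of the same parsing, returning a tuple instead of a tensor; the file names are hypothetical.

```python
from pathlib import Path


def parse_pv4ger_coords(file_path: str) -> tuple[float, float]:
    """Return (latitude, longitude) from a file stem of the form '<lat>,<lon>'."""
    lat_str, lon_str = Path(file_path).stem.split(",")
    return float(lat_str), float(lon_str)


# Hypothetical file name, for illustration only.
print(parse_pv4ger_coords("m-pv4ger-seg/48.1351,11.582.hdf5"))  # (48.1351, 11.582)
```

A stem containing more or fewer than one comma makes the tuple unpacking fail with a ValueError.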

Source code in terratorch/datasets/m_pv4ger_seg.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
    use_metadata: bool = False,
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
        use_metadata (bool): Whether to return metadata info (location coordinates).
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.use_metadata = use_metadata

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.

Source code in terratorch/datasets/m_pv4ger_seg.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    mask = sample["mask"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    num_classes = len(np.unique(mask))
    cmap = plt.get_cmap("jet")
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    num_images = 4 if "prediction" in sample else 3
    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

    ax[0].imshow(rgb_image)
    ax[0].set_title("Image")
    ax[0].axis("off")

    ax[1].imshow(mask, cmap=cmap, norm=norm)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(rgb_image)
    ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
    ax[2].set_title("GT Mask on Image")
    ax[2].axis("off")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[3].imshow(prediction, cmap=cmap, norm=norm)
        ax[3].set_title("Predicted Mask")
        ax[3].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    return fig

terratorch.datasets.m_SA_crop_type

MSACropTypeNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for M-SA-Crop-Type.

Source code in terratorch/datasets/m_SA_crop_type.py
class MSACropTypeNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-SA-Crop-Type](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = (
        "COASTAL_AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "WATER_VAPOR",
        "SWIR_1",
        "SWIR_2",
        "CLOUD_PROBABILITY",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-SA-crop-type"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = [self.all_band_names.index(b) for b in bands]

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            mask = np.array(h5file["label"])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        image = sample["image"]
        mask = sample["mask"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        num_classes = len(np.unique(mask))
        cmap = plt.get_cmap("jet")
        norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig

__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, partition='default')

Initialize the dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands to be used. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

  • partition (str, default: 'default' ) –

    Partition name for the dataset splits. Defaults to 'default'.
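
For M-SA-Crop-Type, bands defaults to all 13 channels, including CLOUD_PROBABILITY; passing a subset such as BAND_SETS["rgb"] keeps only those channels, and the order of the names determines the channel order of the returned image. The sketch below maps names to indices the way band_indices is built in __init__. The unknown-name check is a hypothetical stand-in for TerraTorch's validate_bands, whose exact behavior and message are not shown here.

```python
from collections.abc import Sequence

ALL_BAND_NAMES = (
    "COASTAL_AEROSOL", "BLUE", "GREEN", "RED", "RED_EDGE_1", "RED_EDGE_2", "RED_EDGE_3",
    "NIR_BROAD", "NIR_NARROW", "WATER_VAPOR", "SWIR_1", "SWIR_2", "CLOUD_PROBABILITY",
)


def to_band_indices(bands: Sequence[str], all_band_names: Sequence[str] = ALL_BAND_NAMES) -> list[int]:
    """Map band names to channel indices, rejecting unknown names (hypothetical check)."""
    unknown = [b for b in bands if b not in all_band_names]
    if unknown:
        raise ValueError(f"Unknown bands {unknown}; choose from {list(all_band_names)}.")
    return [all_band_names.index(b) for b in bands]


print(to_band_indices(("RED", "GREEN", "BLUE")))           # [3, 2, 1]
print(to_band_indices(("NIR_BROAD", "SWIR_1", "SWIR_2")))  # [7, 10, 11]
```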

Source code in terratorch/datasets/m_SA_crop_type.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    partition: str = "default",
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = [self.all_band_names.index(b) for b in bands]

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.

Source code in terratorch/datasets/m_SA_crop_type.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]

    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    image = sample["image"]
    mask = sample["mask"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    num_classes = len(np.unique(mask))
    cmap = plt.get_cmap("jet")
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    num_images = 4 if "prediction" in sample else 3
    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

    ax[0].imshow(rgb_image)
    ax[0].set_title("Image")
    ax[0].axis("off")

    ax[1].imshow(mask, cmap=cmap, norm=norm)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(rgb_image)
    ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
    ax[2].set_title("GT Mask on Image")
    ax[2].axis("off")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[3].imshow(prediction, cmap=cmap, norm=norm)
        ax[3].set_title("Predicted Mask")
        ax[3].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    return fig
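The first half of plot() looks up the RGB channels by name, fails if any are missing, and converts the channels-first sample to channels-last before display. The sketch below reproduces that step with NumPy alone. It assumes a channels-first (C, H, W) array, uses an illustrative band tuple, and replaces TerraTorch's clip_image() with a hypothetical min-max rescale, so the pixel values will not match the real helper exactly.

```python
import numpy as np


def rgb_composite(image, bands, rgb_bands=("RED", "GREEN", "BLUE")):
    """Turn a channels-first (C, H, W) array into an (H, W, 3) RGB array scaled to [0, 1]."""
    # Same membership-filtered lookup as plot(): missing RGB bands shorten the list.
    rgb_indices = [bands.index(band) for band in rgb_bands if band in bands]
    if len(rgb_indices) != 3:
        raise ValueError("Dataset doesn't contain some of the RGB bands")
    hwc = np.transpose(np.asarray(image), (1, 2, 0))  # (C, H, W) -> (H, W, C)
    rgb = hwc[:, :, rgb_indices].astype(np.float32)
    # Hypothetical stand-in for clip_image(): a plain min-max rescale to [0, 1].
    lo, hi = rgb.min(), rgb.max()
    if hi > lo:
        return (rgb - lo) / (hi - lo)
    return np.zeros_like(rgb)


bands = ("BLUE", "GREEN", "RED", "NIR")  # illustrative band order, not the dataset's
image = np.arange(4 * 2 * 2).reshape(4, 2, 2)
rgb = rgb_composite(image, bands)
print(rgb.shape)  # (2, 2, 3)
```

Filtering with `if band in bands` before checking the length is what lets the "missing RGB bands" error fire; without the filter, `bands.index` would raise its own, less descriptive ValueError first.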

terratorch.datasets.m_neontree

MNeonTreeNonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for M-NeonTree.

Source code in terratorch/datasets/m_neontree.py
class MNeonTreeNonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [M-NeonTree](https://github.com/ServiceNow/geo-bench?tab=readme-ov-file)."""
    all_band_names = ("BLUE", "CANOPY_HEIGHT_MODEL", "GREEN", "NEON", "RED")

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    splits = {"train": "train", "val": "valid", "test": "test"}

    data_dir = "m-NeonTree"
    partition_file_template = "{partition}_partition.json"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = rgb_bands,
        transform: A.Compose | None = None,
        partition: str = "default",
    ) -> None:
        """Initialize the dataset.

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val', or 'test'.
            bands (Sequence[str]): Bands to be used. Defaults to RGB bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Defaults to None, which applies default_transform().
            partition (str): Partition name for the dataset splits. Defaults to 'default'.
        """
        super().__init__()

        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

        self.data_root = Path(data_root)
        self.data_directory = self.data_root / self.data_dir

        partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
        with open(partition_file) as file:
            partitions = json.load(file)

        if split_name not in partitions:
            msg = f"Split '{split_name}' not found."
            raise ValueError(msg)

        self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        file_path = self.image_files[index]

        with h5py.File(file_path, "r") as h5file:
            keys = sorted(h5file.keys())
            keys = np.array([key for key in keys if key != "label"])[self.band_indices]
            bands = [np.array(h5file[key]) for key in keys]

            image = np.stack(bands, axis=-1)
            mask = np.array(h5file["label"])

        output = {"image": image.astype(np.float32), "mask": mask}

        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        return output

    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
            suptitle (str | None): Optional string to use as a suptitle.

        Returns:
            matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
        """
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)
        image = sample["image"]
        mask = sample["mask"].numpy()

        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()  # (H, W, C)

        rgb_image = image[:, :, rgb_indices]
        rgb_image = clip_image(rgb_image)

        num_classes = len(np.unique(mask))
        cmap = plt.get_cmap("jet")
        norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

        num_images = 4 if "prediction" in sample else 3
        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

        ax[0].imshow(rgb_image)
        ax[0].set_title("Image")
        ax[0].axis("off")

        ax[1].imshow(mask, cmap=cmap, norm=norm)
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis("off")

        ax[2].imshow(rgb_image)
        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
        ax[2].set_title("GT Mask on Image")
        ax[2].axis("off")

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[3].imshow(prediction, cmap=cmap, norm=norm)
            ax[3].set_title("Predicted Mask")
            ax[3].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig
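MNeonTreeNonGeo.__getitem__ does not look bands up by name. It sorts the HDF5 dataset names, drops "label", and indexes the result with band_indices. This only works because all_band_names is already in alphabetical order. The sketch below replays that logic on a plain dict standing in for an h5py.File. The key names and values are illustrative assumptions, not the real file layout.

```python
import numpy as np

ALL_BAND_NAMES = ("BLUE", "CANOPY_HEIGHT_MODEL", "GREEN", "NEON", "RED")


def read_sample(h5_like, bands):
    """Stack the requested bands channels-last, mirroring MNeonTreeNonGeo.__getitem__."""
    band_indices = np.array([ALL_BAND_NAMES.index(b) for b in bands])
    # Sorted dataset names without "label"; position i must correspond to ALL_BAND_NAMES[i].
    keys = np.array([k for k in sorted(h5_like.keys()) if k != "label"])[band_indices]
    image = np.stack([np.asarray(h5_like[k]) for k in keys], axis=-1)
    mask = np.asarray(h5_like["label"])
    return {"image": image.astype(np.float32), "mask": mask}


# A dict standing in for h5py.File; each band is a constant 2x2 array equal to its index.
fake_file = {name: np.full((2, 2), i) for i, name in enumerate(ALL_BAND_NAMES)}
fake_file["label"] = np.zeros((2, 2), dtype=np.int64)
sample = read_sample(fake_file, ("RED", "GREEN", "BLUE"))
print(sample["image"][0, 0])  # [4. 2. 0.]
```

If the real files used key names that sorted differently from all_band_names, this positional mapping would quietly pick the wrong bands. Check the key names once before relying on it.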
__init__(data_root, split='train', bands=rgb_bands, transform=None, partition='default')

Initialize the dataset.

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val', or 'test'.

  • bands (Sequence[str], default: rgb_bands ) –

    Bands to be used. Defaults to RGB bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Defaults to None, which applies default_transform().

  • partition (str, default: 'default' ) –

    Partition name for the dataset splits. Defaults to 'default'.

Source code in terratorch/datasets/m_neontree.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = rgb_bands,
    transform: A.Compose | None = None,
    partition: str = "default",
) -> None:
    """Initialize the dataset.

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train', 'val', or 'test'.
        bands (Sequence[str]): Bands to be used. Defaults to RGB bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Defaults to None, which applies default_transform().
        partition (str): Partition name for the dataset splits. Defaults to 'default'.
    """
    super().__init__()

    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.array([self.all_band_names.index(b) for b in bands])

    self.data_root = Path(data_root)
    self.data_directory = self.data_root / self.data_dir

    partition_file = self.data_directory / self.partition_file_template.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)

    if split_name not in partitions:
        msg = f"Split '{split_name}' not found."
        raise ValueError(msg)

    self.image_files = [self.data_directory / f"{filename}.hdf5" for filename in partitions[split_name]]

    self.transform = transform if transform else default_transform
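The split handling above first maps the public split name ("val") to the name used in the partition file ("valid"). It then reads `{partition}_partition.json` and builds one `.hdf5` path per listed sample. The sketch below reproduces those steps with the standard library only. The temporary directory and sample IDs are made up for the example.

```python
import json
import tempfile
from pathlib import Path

# Split mapping and file template copied from MNeonTreeNonGeo.
SPLITS = {"train": "train", "val": "valid", "test": "test"}
PARTITION_FILE_TEMPLATE = "{partition}_partition.json"


def resolve_image_files(data_directory, split, partition="default"):
    """Return the HDF5 paths listed for `split` in the partition file."""
    if split not in SPLITS:
        raise ValueError(f"Incorrect split '{split}', please choose one of {list(SPLITS)}.")
    split_name = SPLITS[split]
    partition_file = Path(data_directory) / PARTITION_FILE_TEMPLATE.format(partition=partition)
    with open(partition_file) as file:
        partitions = json.load(file)
    if split_name not in partitions:
        raise ValueError(f"Split '{split_name}' not found.")
    return [Path(data_directory) / f"{name}.hdf5" for name in partitions[split_name]]


# Usage with a throwaway directory and illustrative sample IDs.
with tempfile.TemporaryDirectory() as tmp:
    layout = {"train": ["a", "b"], "valid": ["c"], "test": ["d"]}
    (Path(tmp) / "default_partition.json").write_text(json.dumps(layout))
    files = resolve_image_files(tmp, "val")
    print([f.name for f in files])  # ['c.hdf5']
```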
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.

Source code in terratorch/datasets/m_neontree.py
def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
    """Plot a sample from the dataset.

    Args:
        sample (dict[str, torch.Tensor]): A sample returned by :meth:`__getitem__`.
        suptitle (str | None): Optional string to use as a suptitle.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure with the rendered sample.
    """
    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)
    image = sample["image"]
    mask = sample["mask"].numpy()

    if torch.is_tensor(image):
        image = image.permute(1, 2, 0).numpy()  # (H, W, C)

    rgb_image = image[:, :, rgb_indices]
    rgb_image = clip_image(rgb_image)

    num_classes = len(np.unique(mask))
    cmap = plt.get_cmap("jet")
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    num_images = 4 if "prediction" in sample else 3
    fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)

    ax[0].imshow(rgb_image)
    ax[0].set_title("Image")
    ax[0].axis("off")

    ax[1].imshow(mask, cmap=cmap, norm=norm)
    ax[1].set_title("Ground Truth Mask")
    ax[1].axis("off")

    ax[2].imshow(rgb_image)
    ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
    ax[2].set_title("GT Mask on Image")
    ax[2].axis("off")

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[3].imshow(prediction, cmap=cmap, norm=norm)
        ax[3].set_title("Predicted Mask")
        ax[3].axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    return fig

terratorch.datasets.multi_temporal_crop_classification

MultiTemporalCropClassification

Bases: NonGeoDataset

NonGeo dataset implementation for multi-temporal crop classification.

Source code in terratorch/datasets/multi_temporal_crop_classification.py
class MultiTemporalCropClassification(NonGeoDataset):
    """NonGeo dataset implementation for [multi-temporal crop classification](https://huggingface.co/datasets/ibm-nasa-geospatial/multi-temporal-crop-classification)."""

    all_band_names = (
        "BLUE",
        "GREEN",
        "RED",
        "NIR_NARROW",
        "SWIR_1",
        "SWIR_2",
    )

    class_names = (
        "Natural Vegetation",
        "Forest",
        "Corn",
        "Soybeans",
        "Wetlands",
        "Developed / Barren",
        "Open Water",
        "Winter Wheat",
        "Alfalfa",
        "Fallow / Idle Cropland",
        "Cotton",
        "Sorghum",
        "Other",
    )

    rgb_bands = ("RED", "GREEN", "BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    num_classes = 13
    time_steps = 3
    splits = {"train": "training", "val": "validation"}  # Only train and val splits available
    metadata_file_name = "chip_df_final.csv"
    col_name = "chip_id"
    date_columns = ["first_img_date", "middle_img_date", "last_img_date"]

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        no_data_replace: float | None = None,
        no_label_replace: int | None = None,
        expand_temporal_dimension: bool = True,
        reduce_zero_label: bool = True,
        use_metadata: bool = False,
    ) -> None:
        """Constructor

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train' or 'val'.
            bands (Sequence[str]): Bands that should be output by the dataset. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). If used through the corresponding data module,
                should not include normalization. Defaults to None, which applies ToTensorV2().
            no_data_replace (float | None): Replace nan values in input images with this value.
                If None, does no replacement. Defaults to None.
            no_label_replace (int | None): Replace nan values in label with this value.
                If None, does no replacement. Defaults to None.
            expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to True.
            reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to True.
            use_metadata (bool): Whether to also return metadata (location and temporal coordinates).
                Defaults to False.
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
        self.data_root = Path(data_root)

        data_dir = self.data_root / f"{split_name}_chips"
        self.image_files = sorted(glob.glob(os.path.join(data_dir, "*_merged.tif")))
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(data_dir, "*.mask.tif")))
        split_file = data_dir / f"{split_name}_data.txt"

        with open(split_file) as f:
            split_lines = f.readlines()
        valid_files = {line.strip() for line in split_lines}
        self.image_files = filter_valid_files(
            self.image_files,
            valid_files=valid_files,
            ignore_extensions=True,
            allow_substring=True,
        )
        self.segmentation_mask_files = filter_valid_files(
            self.segmentation_mask_files,
            valid_files=valid_files,
            ignore_extensions=True,
            allow_substring=True,
        )

        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.reduce_zero_label = reduce_zero_label
        self.expand_temporal_dimension = expand_temporal_dimension
        self.use_metadata = use_metadata
        self.metadata = None
        if self.use_metadata:
            metadata_file = self.data_root / self.metadata_file_name
            self.metadata = pd.read_csv(metadata_file)
            self._build_image_metadata_mapping()

        # If no transform is given, only convert the sample to torch tensors
        self.transform = transform if transform else default_transform

    def _build_image_metadata_mapping(self):
        """Build a mapping from image filenames to metadata indices."""
        self.image_to_metadata_index = dict()

        for idx, image_file in enumerate(self.image_files):
            image_filename = Path(image_file).name
            image_id = image_filename.replace("_merged.tif", "").replace(".tif", "")
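            metadata_indices = self.metadata.index[self.metadata[self.col_name] == image_id].tolist()
            # Chips without a metadata row are skipped; __getitem__ then returns temporal_coords=None.
            if metadata_indices:
                self.image_to_metadata_index[idx] = metadata_indices[0]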

    def __len__(self) -> int:
        return len(self.image_files)

    def _get_date(self, row: pd.Series) -> torch.Tensor:
        """Extract and format temporal coordinates (T, date) from metadata."""
        temporal_coords = []
        for col in self.date_columns:
            date_str = row[col]
            date = pd.to_datetime(date_str, format="%Y-%m-%d")
            temporal_coords.append([date.year, date.dayofyear - 1])

        return torch.tensor(temporal_coords, dtype=torch.float32)

    def _get_coords(self, image: DataArray) -> torch.Tensor:
        px = image.x.shape[0] // 2
        py = image.y.shape[0] // 2

        # get center point to reproject to lat/lon
        point = image.isel(band=0, x=slice(px, px + 1), y=slice(py, py + 1))
        point = point.rio.reproject("epsg:4326")

        lat_lon = np.asarray([point.y[0], point.x[0]])

        return torch.tensor(lat_lon, dtype=torch.float32)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace)

        location_coords, temporal_coords = None, None
        if self.use_metadata:
            location_coords = self._get_coords(image)
            metadata_idx = self.image_to_metadata_index.get(index, None)
            if metadata_idx is not None:
                row = self.metadata.iloc[metadata_idx]
                temporal_coords = self._get_date(row)

        # to channels last
        image = image.to_numpy()
        if self.expand_temporal_dimension:
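            # The file stores every band, so split (channels time) using the full band count;
            # the requested bands are selected afterwards via self.band_indices.
            image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.all_band_names))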
        image = np.moveaxis(image, 0, -1)

        # filter bands
        image = image[..., self.band_indices]

        output = {
            "image": image.astype(np.float32),
            "mask": self._load_file(
                self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[0],
        }

        if self.reduce_zero_label:
            output["mask"] -= 1
        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        if self.use_metadata:
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def _load_file(self, path: Path, nan_replace: int | float | None = None) -> DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        num_images = self.time_steps + 2

        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
        if len(rgb_indices) != 3:
            msg = "Dataset doesn't contain some of the RGB bands"
            raise ValueError(msg)

        images = sample["image"]
        images = images[rgb_indices, ...]  # Shape: (3, T, H, W)

        processed_images = []
        for t in range(self.time_steps):
            img = images[:, t]  # (3, H, W) frame for time step t
            img = img.permute(1, 2, 0)
            img = img.numpy()
            img = clip_image(img)
            processed_images.append(img)

        mask = sample["mask"].numpy()
        if "prediction" in sample:
            num_images += 1
        fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")
        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
        for i, img in enumerate(processed_images):
            ax[i + 1].axis("off")
            ax[i + 1].title.set_text(f"T{i}")
            ax[i + 1].imshow(img)

        ax[self.time_steps + 1].axis("off")
        ax[self.time_steps + 1].title.set_text("Ground Truth Mask")
        ax[self.time_steps + 1].imshow(mask, cmap="jet", norm=norm)

        if "prediction" in sample:
            prediction = sample["prediction"].numpy()
            ax[self.time_steps + 2].axis("off")
            ax[self.time_steps + 2].title.set_text("Predicted Mask")
            ax[self.time_steps + 2].imshow(prediction, cmap="jet", norm=norm)

        cmap = plt.get_cmap("jet")
        legend_data = [[i, cmap(norm(i)), self.class_names[i]] for i in range(self.num_classes)]
        handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
        labels = [n for k, c, n in legend_data]
        ax[0].legend(handles, labels, loc="center")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
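When use_metadata=True, each chip is matched to a CSV row by stripping the `_merged.tif` suffix from its filename. _get_date then turns the three acquisition dates into `[year, day_of_year - 1]` pairs. The sketch below mirrors both steps with the standard library. It returns plain lists where the dataset returns a float32 tensor, and the chip filename and dates are illustrative assumptions.

```python
from datetime import date
from pathlib import Path

DATE_COLUMNS = ["first_img_date", "middle_img_date", "last_img_date"]


def chip_id(image_file):
    """Strip the file suffix to recover the chip_id used as the metadata key."""
    return Path(image_file).name.replace("_merged.tif", "").replace(".tif", "")


def temporal_coords(row):
    """[year, zero-based day of year] for each acquisition date, as in _get_date."""
    coords = []
    for col in DATE_COLUMNS:
        d = date.fromisoformat(row[col])  # dates are expected as "%Y-%m-%d"
        coords.append([d.year, d.timetuple().tm_yday - 1])
    return coords


row = {"first_img_date": "2022-01-01", "middle_img_date": "2022-03-01", "last_img_date": "2022-12-31"}
print(chip_id("/data/training_chips/chip_001_002_merged.tif"))  # chip_001_002
print(temporal_coords(row))  # [[2022, 0], [2022, 59], [2022, 364]]
```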
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, no_data_replace=None, no_label_replace=None, expand_temporal_dimension=True, reduce_zero_label=True, use_metadata=False)

Constructor

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train' or 'val'.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    Bands that should be output by the dataset. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). If used through the corresponding data module, should not include normalization. Defaults to None, which applies ToTensorV2().

  • no_data_replace (float | None, default: None ) –

    Replace nan values in input images with this value. If None, does no replacement. Defaults to None.

  • no_label_replace (int | None, default: None ) –

    Replace nan values in label with this value. If None, does no replacement. Defaults to None.

  • expand_temporal_dimension (bool, default: True ) –

    Go from shape (time*channels, h, w) to (channels, time, h, w). Defaults to True.

  • reduce_zero_label (bool, default: True ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to True.

  • use_metadata (bool, default: False ) –

    Whether to also return metadata (location and temporal coordinates). Defaults to False.
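With expand_temporal_dimension=True, the stacked raster of shape (channels*time, H, W) is split channel-major, following the einops pattern "(channels time) h w -> channels time h w". It is then moved to channels-last, and the requested bands are picked from the last axis. The sketch below reproduces this with numpy.reshape, which matches that einops pattern for a channel-major layout. It assumes the six bands and three time steps declared on the class, and the spatial size is arbitrary.

```python
import numpy as np

NUM_BANDS, TIME_STEPS, H, W = 6, 3, 4, 4  # 6 bands x 3 time steps, as declared on the class


def to_channels_time(image, num_bands):
    """(channels*time, H, W) -> (channels, time, H, W), channel-major like the einops pattern."""
    ct, h, w = image.shape
    return image.reshape(num_bands, ct // num_bands, h, w)


flat = np.arange(NUM_BANDS * TIME_STEPS * H * W, dtype=np.float32).reshape(NUM_BANDS * TIME_STEPS, H, W)
expanded = to_channels_time(flat, NUM_BANDS)   # (6, 3, 4, 4)
channels_last = np.moveaxis(expanded, 0, -1)   # (3, 4, 4, 6) = (time, H, W, channels)
rgb_only = channels_last[..., [2, 1, 0]]       # RED, GREEN, BLUE in all_band_names order
print(rgb_only.shape)  # (3, 4, 4, 3)
```

The split must use the number of bands stored in the file (six here), not the number of bands requested. Raster index `c * TIME_STEPS + t` lands at `expanded[c, t]` only when the full count is used.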

Source code in terratorch/datasets/multi_temporal_crop_classification.py
def __init__(
    self,
    data_root: str,
    split: str = "train",
    bands: Sequence[str] = BAND_SETS["all"],
    transform: A.Compose | None = None,
    no_data_replace: float | None = None,
    no_label_replace: int | None = None,
    expand_temporal_dimension: bool = True,
    reduce_zero_label: bool = True,
    use_metadata: bool = False,
) -> None:
    """Constructor

    Args:
        data_root (str): Path to the data root directory.
        split (str): One of 'train' or 'val'.
        bands (Sequence[str]): Bands that should be output by the dataset. Defaults to all bands.
        transform (A.Compose | None): Albumentations transform to be applied.
            Should end with ToTensorV2(). If used through the corresponding data module,
            should not include normalization. Defaults to None, which applies ToTensorV2().
        no_data_replace (float | None): Replace nan values in input images with this value.
            If None, does no replacement. Defaults to None.
        no_label_replace (int | None): Replace nan values in label with this value.
            If None, does no replacement. Defaults to None.
        expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w).
            Defaults to True.
        reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the
            expected 0. Defaults to True.
        use_metadata (bool): Whether to also return metadata (location and temporal coordinates).
            Defaults to False.
    """
    super().__init__()
    if split not in self.splits:
        msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
        raise ValueError(msg)
    split_name = self.splits[split]
    self.split = split

    validate_bands(bands, self.all_band_names)
    self.bands = bands
    self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
    self.data_root = Path(data_root)

    data_dir = self.data_root / f"{split_name}_chips"
    self.image_files = sorted(glob.glob(os.path.join(data_dir, "*_merged.tif")))
    self.segmentation_mask_files = sorted(glob.glob(os.path.join(data_dir, "*.mask.tif")))
    split_file = data_dir / f"{split_name}_data.txt"

    with open(split_file) as f:
        split_lines = f.readlines()
    valid_files = {line.strip() for line in split_lines}
    self.image_files = filter_valid_files(
        self.image_files,
        valid_files=valid_files,
        ignore_extensions=True,
        allow_substring=True,
    )
    self.segmentation_mask_files = filter_valid_files(
        self.segmentation_mask_files,
        valid_files=valid_files,
        ignore_extensions=True,
        allow_substring=True,
    )

    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.reduce_zero_label = reduce_zero_label
    self.expand_temporal_dimension = expand_temporal_dimension
    self.use_metadata = use_metadata
    self.metadata = None
    if self.use_metadata:
        metadata_file = self.data_root / self.metadata_file_name
        self.metadata = pd.read_csv(metadata_file)
        self._build_image_metadata_mapping()

    # If no transform is given, only convert the sample to torch tensors
    self.transform = transform if transform else default_transform
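The label handling above works in two steps. NaNs are replaced with no_label_replace when the mask is loaded, and reduce_zero_label later subtracts 1 from every pixel. The replacement value is therefore shifted too. The sketch below illustrates this with NumPy on a made-up 2x2 mask whose classes run from 1 to 13.

```python
from typing import Optional

import numpy as np


def prepare_mask(raw_mask, no_label_replace: Optional[int] = None, reduce_zero_label: bool = True):
    """Fill NaN labels, optionally shift classes down by one, and cast to int64."""
    mask = np.asarray(raw_mask, dtype=np.float32)
    if no_label_replace is not None:
        mask = np.where(np.isnan(mask), no_label_replace, mask)  # done at load time in the dataset
    if reduce_zero_label:
        mask = mask - 1  # classes 1..13 become 0..12; the fill value is shifted as well
    return mask.astype(np.int64)


raw = np.array([[1, 2], [np.nan, 13]])
print(prepare_mask(raw, no_label_replace=0).tolist())  # [[0, 1], [-1, 12]]
```

So passing `no_label_replace=0` together with `reduce_zero_label=True` ends up marking missing labels as -1. Leaving no_label_replace as None with NaNs present would make the final integer cast meaningless.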
plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    A sample returned by __getitem__.

  • suptitle (str | None, default: None ) –

    Optional string to use as a suptitle.

Returns:
  • Figure

    A matplotlib Figure with the rendered sample.

Source code in terratorch/datasets/multi_temporal_crop_classification.py
def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
    """Plot a sample from the dataset.

    Args:
        sample: a sample returned by :meth:`__getitem__`
        suptitle: optional string to use as a suptitle

    Returns:
        a matplotlib Figure with the rendered sample
    """
    num_images = self.time_steps + 2

    rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
    if len(rgb_indices) != 3:
        msg = "Dataset doesn't contain some of the RGB bands"
        raise ValueError(msg)

    images = sample["image"]
    images = images[rgb_indices, ...]  # Shape: (3, T, H, W)

    processed_images = []
    for t in range(self.time_steps):
        img = images[:, t]  # (3, H, W) frame for time step t
        img = img.permute(1, 2, 0)
        img = img.numpy()
        img = clip_image(img)
        processed_images.append(img)

    mask = sample["mask"].numpy()
    if "prediction" in sample:
        num_images += 1
    fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")
    ax[0].axis("off")

    norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
    for i, img in enumerate(processed_images):
        ax[i + 1].axis("off")
        ax[i + 1].title.set_text(f"T{i}")
        ax[i + 1].imshow(img)

    ax[self.time_steps + 1].axis("off")
    ax[self.time_steps + 1].title.set_text("Ground Truth Mask")
    ax[self.time_steps + 1].imshow(mask, cmap="jet", norm=norm)

    if "prediction" in sample:
        prediction = sample["prediction"].numpy()
        ax[self.time_steps + 2].axis("off")
        ax[self.time_steps + 2].title.set_text("Predicted Mask")
        ax[self.time_steps + 2].imshow(prediction, cmap="jet", norm=norm)

    cmap = plt.get_cmap("jet")
    legend_data = [[i, cmap(norm(i)), self.class_names[i]] for i in range(self.num_classes)]
    handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
    labels = [n for k, c, n in legend_data]
    ax[0].legend(handles, labels, loc="center")

    if suptitle is not None:
        plt.suptitle(suptitle)

    return fig
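plot() draws one RGB panel per time step. This sketch shows how per-time-step frames can be pulled out of a sample for that purpose. It assumes the transformed image is laid out as (C, T, H, W), channels first and time second, which is what a channels-last (T, H, W, C) array becomes after moving the last axis to the front. The arange input exists only to make the indexing checkable.

```python
import numpy as np


def rgb_frames(image, rgb_indices):
    """Split a (C, T, H, W) array into T channels-last RGB frames of shape (H, W, 3)."""
    rgb = np.asarray(image)[rgb_indices]  # (3, T, H, W)
    return [np.transpose(rgb[:, t], (1, 2, 0)) for t in range(rgb.shape[1])]


image = np.arange(6 * 3 * 2 * 2).reshape(6, 3, 2, 2)  # (C=6, T=3, H=2, W=2)
frames = rgb_frames(image, [2, 1, 0])                  # RED, GREEN, BLUE indices
print(len(frames), frames[1].shape)  # 3 (2, 2, 3)
```

Indexing `rgb[t]` instead of `rgb[:, t]` would pick channel t across all time steps. Because this dataset has three RGB bands and three time steps, the shapes would still match and the mistake would go unnoticed.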

terratorch.datasets.open_sentinel_map

OpenSentinelMap

Bases: NonGeoDataset

PyTorch Dataset class to load samples from the OpenSentinelMap dataset, supporting multiple bands and temporal sampling strategies.
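The constructor below starts by validating its arguments. Band groups must be among "gsd_10", "gsd_20" and "gsd_60" (all three by default), and the public split name is mapped to the value stored in the "split" column of spatial_cell_info.csv. The sketch reproduces only that validation in plain Python. The function name is hypothetical, and CSV loading is left out.

```python
ALLOWED_BANDS = ("gsd_10", "gsd_20", "gsd_60")
SPLIT_MAPPING = {"train": "training", "val": "validation", "test": "testing"}


def resolve_config(split="train", bands=None):
    """Hypothetical helper: validate bands and map the split to its CSV value."""
    if bands is None:
        bands = list(ALLOWED_BANDS)
    for band in bands:
        if band not in ALLOWED_BANDS:
            raise ValueError(f"Band '{band}' is not recognized. Available values are: {', '.join(ALLOWED_BANDS)}")
    if split not in SPLIT_MAPPING:
        raise ValueError(f"Split '{split}' not recognized. Use 'train', 'val', or 'test'.")
    return SPLIT_MAPPING[split], bands


print(resolve_config("val"))  # ('validation', ['gsd_10', 'gsd_20', 'gsd_60'])
```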

Source code in terratorch/datasets/open_sentinel_map.py
class OpenSentinelMap(NonGeoDataset):
    """
    PyTorch Dataset class to load samples from the [OpenSentinelMap](https://visionsystemsinc.github.io/open-sentinel-map/) dataset,
    supporting multiple bands and temporal sampling strategies.
    """
    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: list[str] | None = None,
        transform: A.Compose | None = None,
        spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT001, FBT002
        pad_image: int | None = None,
        truncate_image: int | None = None,
        target: int = 0,
        pick_random_pair: bool = True,  # noqa: FBT002, FBT001
    ) -> None:
        """

        Args:
            data_root (str): Path to the root directory of the dataset.
            split (str): Dataset split to load. Options are 'train', 'val', or 'test'. Defaults to 'train'.
            bands (list of str, optional): List of band names to load. Defaults to ['gsd_10', 'gsd_20', 'gsd_60'].
            transform (albumentations.Compose, optional): Albumentations transformations to apply to the data.
            spatial_interpolate_and_stack_temporally (bool): If True, each band group is bilinearly resized to a
                common spatial size, the groups are concatenated along the channel axis, and the time steps are
                stacked into one array. If False, a dictionary of per-band arrays is returned. Default is True.
            pad_image (int, optional): Number of timesteps to pad the time dimension of the image to.
                If None, no padding is applied.
            truncate_image (int, optional): Number of most recent timesteps to keep along the time dimension
                of the image. If None, no truncation is performed.
            target (int): Index of the label-image channel to use as the mask. Default is 0.
            pick_random_pair (bool): If True, selects two random images from the temporal sequence and orders
                them chronologically. Default is True.
        """
        if bands is None:
            bands = ["gsd_10", "gsd_20", "gsd_60"]

        allowed_bands = {"gsd_10", "gsd_20", "gsd_60"}
        for band in bands:
            if band not in allowed_bands:
                msg = f"Band '{band}' is not recognized. Available values are: {', '.join(allowed_bands)}"
                raise ValueError(msg)

        if split not in ["train", "val", "test"]:
            msg = f"Split '{split}' not recognized. Use 'train', 'val', or 'test'."
            raise ValueError(msg)

        self.data_root = Path(data_root)
        split_mapping = {"train": "training", "val": "validation", "test": "testing"}
        split = split_mapping[split]
        self.imagery_root = self.data_root / "osm_sentinel_imagery"
        self.label_root = self.data_root / "osm_label_images_v10"
        self.auxiliary_data = pd.read_csv(self.data_root / "spatial_cell_info.csv")
        self.auxiliary_data = self.auxiliary_data[self.auxiliary_data["split"] == split]
        self.bands = bands
        self.transform = transform if transform else lambda **batch: to_tensor(batch)
        self.label_mappings = self._load_label_mappings()
        self.split_data = self.auxiliary_data[self.auxiliary_data["split"] == split]
        self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally
        self.pad_image = pad_image
        self.truncate_image = truncate_image
        self.target = target
        self.pick_random_pair = pick_random_pair

        self.image_files = []
        self.label_files = []

        for _, row in self.split_data.iterrows():
            mgrs_tile = row["MGRS_tile"]
            spatial_cell = str(row["cell_id"])

            label_file = self.label_root / mgrs_tile / f"{spatial_cell}.png"

            if label_file.exists():
                self.image_files.append((mgrs_tile, spatial_cell))
                self.label_files.append(label_file)

    def _load_label_mappings(self):
        with open(self.data_root / "osm_categories.json") as f:
            return json.load(f)

    def _extract_date_from_filename(self, filename: str) -> str:
        match = re.search(r"(\d{8})", filename)
        if match:
            return match.group(1)
        else:
            msg = f"Date not found in filename {filename}"
            raise ValueError(msg)

    def __len__(self) -> int:
        return len(self.image_files)

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        if "gsd_10" not in self.bands:
            return None

        num_images = len([key for key in sample if key.startswith("image")])
        images = []

        for i in range(1, num_images + 1):
            image_dict = sample[f"image{i}"]
            image = image_dict["gsd_10"]
            if isinstance(image, Tensor):
                image = image.numpy()

            image = image.take(range(3), axis=2)
            image = image.squeeze()
            image_min = image.min(axis=(0, 1))
            image = (image - image_min) / (image.max(axis=(0, 1)) - image_min)
            image = np.clip(image, 0, 1)
            images.append(image)

        label_mask = sample["mask"]
        if isinstance(label_mask, Tensor):
            label_mask = label_mask.numpy()

        return self._plot_sample(images, label_mask, suptitle=suptitle)


    def _plot_sample(
        self, images: list[np.ndarray],
        label: np.ndarray,
        suptitle: str | None = None,
    ) -> Figure:
        num_images = len(images)
        fig, ax = plt.subplots(1, num_images + 1, figsize=(15, 5))

        for i, image in enumerate(images):
            ax[i].imshow(image)
            ax[i].set_title(f"Image {i + 1}")
            ax[i].axis("off")

        ax[-1].imshow(label, cmap="gray")
        ax[-1].set_title("Ground Truth Mask")
        ax[-1].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        return fig

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        mgrs_tile, spatial_cell = self.image_files[index]
        spatial_cell_path = self.imagery_root / mgrs_tile / spatial_cell

        npz_files = list(spatial_cell_path.glob("*.npz"))
        npz_files.sort(key=lambda x: self._extract_date_from_filename(x.stem))

        if self.pick_random_pair:
            npz_files = random.sample(npz_files, 2)
            npz_files.sort(key=lambda x: self._extract_date_from_filename(x.stem))

        output = {}

        if self.spatial_interpolate_and_stack_temporally:
            images_over_time = []
            for _, npz_file in enumerate(npz_files):
                data = np.load(npz_file)
                interpolated_bands = []
                for band in self.bands:
                    band_frame = data[band]
                    band_frame = torch.from_numpy(band_frame).float()
                    band_frame = band_frame.permute(2, 0, 1)
                    interpolated = F.interpolate(
                        band_frame.unsqueeze(0), size=MAX_TEMPORAL_IMAGE_SIZE, mode="bilinear", align_corners=False
                    ).squeeze(0)
                    interpolated_bands.append(interpolated)
                concatenated_bands = torch.cat(interpolated_bands, dim=0)
                images_over_time.append(concatenated_bands)

            images = torch.stack(images_over_time, dim=0).numpy()
            if self.truncate_image:
                images = images[-self.truncate_image:]
            if self.pad_image:
                images = pad_numpy(images, self.pad_image)

            output["image"] = images.transpose(0, 2, 3, 1)
        else:
            image_dict = {band: [] for band in self.bands}
            for _, npz_file in enumerate(npz_files):
                data = np.load(npz_file)
                for band in self.bands:
                    band_frames = data[band]
                    band_frames = band_frames.astype(np.float32)
                    band_frames = np.transpose(band_frames, (2, 0, 1))
                    image_dict[band].append(band_frames)

            final_image_dict = {}
            for band in self.bands:
                band_images = image_dict[band]
                if self.truncate_image:
                    band_images = band_images[-self.truncate_image:]
                if self.pad_image:
                    band_images = [pad_numpy(img, self.pad_image) for img in band_images]
                band_images = np.stack(band_images, axis=0)
                final_image_dict[band] = band_images

            output["image"] = final_image_dict

        label_file = self.label_files[index]
        mask = np.array(Image.open(label_file)).astype(int)

        # Map 'unlabel' (254) and 'none' (255) to unused classes 15 and 16 for processing
        mask[mask == 254] = 15  # noqa: PLR2004
        mask[mask == 255] = 16  # noqa: PLR2004
        output["mask"] = mask[:, :, self.target]

        if self.transform:
            output = self.transform(**output)

        return output
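With spatial_interpolate_and_stack_temporally=True, __getitem__ processes each band group in several steps. It moves the group from (H, W, C) to (C, H, W) and resizes it with bilinear F.interpolate to MAX_TEMPORAL_IMAGE_SIZE. It then concatenates the groups along the channel axis and stacks the time steps. Optionally, it keeps only the last truncate_image steps, and finally it transposes the result to (T, H, W, C).

The NumPy sketch below reproduces that shape bookkeeping under a few assumptions:
  • Bilinear interpolation is replaced by nearest-neighbour upsampling by an integer factor, to keep dependencies light.
  • The array sizes and channel counts are made up for illustration.
  • Padding is omitted because pad_numpy is not shown here.

```python
import numpy as np


def upsample_nearest(frame_chw: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour upsampling by an integer factor.

    Stand-in for the bilinear F.interpolate call used by the dataset.
    """
    _, height, width = frame_chw.shape
    if size % height or size % width:
        raise ValueError("this sketch only supports integer upsampling factors")
    return frame_chw.repeat(size // height, axis=1).repeat(size // width, axis=2)


def stack_time_steps(time_steps, bands, size, truncate=None):
    """Turn a list of per-date band dicts into a (T, H, W, C) array."""
    frames = []
    for step in time_steps:  # one dict of (H, W, C) arrays per .npz file
        per_band = [
            upsample_nearest(np.transpose(step[band], (2, 0, 1)).astype(np.float32), size)
            for band in bands
        ]
        frames.append(np.concatenate(per_band, axis=0))  # (C_total, size, size)
    images = np.stack(frames, axis=0)  # (T, C_total, size, size)
    if truncate:
        images = images[-truncate:]  # keep the most recent `truncate` steps
    return images.transpose(0, 2, 3, 1)  # (T, H, W, C)


rng = np.random.default_rng(0)
# Hypothetical sizes: the gsd_10, gsd_20 and gsd_60 groups at 12x12, 6x6 and 2x2 pixels.
steps = [
    {"gsd_10": rng.random((12, 12, 4)), "gsd_20": rng.random((6, 6, 6)), "gsd_60": rng.random((2, 2, 3))}
    for _ in range(5)
]
print(stack_time_steps(steps, ["gsd_10", "gsd_20", "gsd_60"], size=12, truncate=3).shape)  # (3, 12, 12, 13)
```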
__init__(data_root, split='train', bands=None, transform=None, spatial_interpolate_and_stack_temporally=True, pad_image=None, truncate_image=None, target=0, pick_random_pair=True)
Parameters:
  • data_root (str) –

    Path to the root directory of the dataset.

  • split (str, default: 'train' ) –

    Dataset split to load. Options are 'train', 'val', or 'test'. Defaults to 'train'.

  • bands (list of str, default: None ) –

    List of band names to load. Defaults to ['gsd_10', 'gsd_20', 'gsd_60'].

  • transform (Compose, default: None ) –

    Albumentations transformations to apply to the data.

  • spatial_interpolate_and_stack_temporally (bool, default: True ) –

    If True, each band group is bilinearly resized to a common spatial size, the groups are concatenated along the channel axis, and the time steps are stacked into one array. If False, a dictionary of per-band arrays is returned. Default is True.

  • pad_image (int, default: None ) –

    Number of timesteps to pad the time dimension of the image to. If None, no padding is applied.

  • truncate_image (int, default: None ) –

    Number of most recent timesteps to keep along the time dimension of the image. If None, no truncation is performed.

  • target (int, default: 0 ) –

    Index of the label-image channel to use as the mask. Default is 0.

  • pick_random_pair (bool, default: True ) –

    If True, selects two random images from the temporal sequence and orders them chronologically. Default is True.
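The label handling at the end of __getitem__ works in three steps:
  • It reads the label PNG as an (H, W, C) integer array.
  • It remaps the 'unlabel' value 254 to class 15 and the 'none' value 255 to class 16.
  • It keeps only channel target.

Below is a minimal NumPy sketch of that step, using a tiny made-up label array. select_target is a hypothetical helper name.

```python
import numpy as np

# Reserved values in the label PNG and the class indices they are mapped to.
REMAP = {254: 15, 255: 16}  # 'unlabel' -> 15, 'none' -> 16


def select_target(label_image: np.ndarray, target: int = 0) -> np.ndarray:
    """Remap reserved values and return the (H, W) mask for one label channel."""
    mask = label_image.astype(int)  # astype copies, so the input is left untouched
    for value, class_index in REMAP.items():
        mask[mask == value] = class_index
    return mask[:, :, target]


label = np.zeros((2, 2, 3), dtype=np.uint8)
label[0, 0, 0] = 254
label[1, 1, 0] = 255
print(select_target(label, target=0))
```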


terratorch.datasets.openearthmap

OpenEarthMapNonGeo

Bases: NonGeoDataset

OpenEarthMapNonGeo Dataset for non-georeferenced imagery.

This dataset class handles non-georeferenced image data from the OpenEarthMap dataset. It supports configurable band sets and transformations. Because the source images are not all the same size, each image and its mask are cropped to a square of crop_size pixels, either at a random position or at the center. The split parameter selects the "train", "test", or "val" subset.

Source code in terratorch/datasets/openearthmap.py
class OpenEarthMapNonGeo(NonGeoDataset):
    """
    [OpenEarthMapNonGeo](https://open-earth-map.org/) Dataset for non-georeferenced imagery.

    This dataset class handles non-georeferenced image data from the OpenEarthMap dataset.
    It supports configurable band sets and transformations, and performs cropping operations
    to ensure that the images conform to the required input dimensions. The dataset is split
    into "train", "test", and "val" subsets based on the provided split parameter.
    """


    all_band_names = ("BLUE","GREEN","RED")

    rgb_bands = ("RED","GREEN","BLUE")

    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}

    def __init__(self, data_root: str,
                 bands: Sequence[str] = BAND_SETS["all"],
                 transform: A.Compose | None = None,
                 split="train",
                 crop_size: int = 256,
                 random_crop: bool = True) -> None:
        """
        Initialize a new instance of the OpenEarthMapNonGeo dataset.

        Args:
            data_root (str): The root directory containing the dataset files.
            bands (Sequence[str], optional): A list of band names to be used. Default is BAND_SETS["all"].
            transform (A.Compose or None, optional): A transformation pipeline to be applied to the data.
                If None, a default transform converting the data to a tensor is applied.
            split (str, optional): The dataset split to use ("train", "test", or "val"). Default is "train".
            crop_size (int, optional): The size (in pixels) of the crop to apply to images. Must be greater than 0.
                Default is 256.
            random_crop (bool, optional): If True, performs a random crop; otherwise, performs a center crop.
                Default is True.

        Raises:
            Exception: If the provided split is not one of "train", "test", or "val".
            AssertionError: If crop_size is not greater than 0.
        """
        super().__init__()
        if split not in ["train", "test", "val"]:
            msg = "Split must be one of train, test, val."
            raise Exception(msg)

        self.transform = transform if transform else lambda **batch: to_tensor(batch, transpose=False)
        self.split = split
        self.data_root = data_root

        # images in openearthmap are not all 1024x1024 and must be cropped
        self.crop_size = crop_size
        self.random_crop = random_crop

        assert self.crop_size > 0, "Crop size must be greater than 0"

        self.image_files = self._get_file_paths(Path(self.data_root, f"{split}.txt"))

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        image_path, label_path = self.image_files[index]

        with rasterio.open(image_path) as src:
            image = src.read()
        with rasterio.open(label_path) as src:
            mask = src.read()

        # some images in the dataset are not perfect squares
        # cropping to fit to the prepare_features_for_image_model call
        if self.random_crop:
            image, mask = self._random_crop(image, mask)
        else:
            image, mask = self._center_crop(image, mask)

        output =  {
            "image": image.astype(np.float32),
            "mask": mask
        }

        output = self.transform(**output)
        output['mask'] = output['mask'].long()

        return output

    def _parse_file_name(self, file_name: str):
        underscore_pos = file_name.rfind('_')
        folder_name = file_name[:underscore_pos]
        region_path = Path(self.data_root, folder_name)
        image_path = Path(region_path, "images", file_name)
        label_path = Path(region_path, "labels", file_name)
        return image_path, label_path

    def _get_file_paths(self, text_file_path: str):
        with open(text_file_path, 'r') as file:
            lines = file.readlines()
            file_paths = [self._parse_file_name(line.strip()) for line in lines]
        return file_paths

    def __len__(self):
        return len(self.image_files)

    def _random_crop(self, image, mask):
        h, w = image.shape[1:]
        top = np.random.randint(0, h - self.crop_size + 1)  # upper bound is exclusive
        left = np.random.randint(0, w - self.crop_size + 1)

        image = image[:, top: top + self.crop_size, left: left + self.crop_size]
        mask = mask[:, top: top + self.crop_size, left: left + self.crop_size]

        return image, mask

    def _center_crop(self, image, mask):
        h, w = image.shape[1:]
        top = (h - self.crop_size) // 2
        left = (w - self.crop_size) // 2

        image = image[:, top: top + self.crop_size, left: left + self.crop_size]
        mask = mask[:, top: top + self.crop_size, left: left + self.crop_size]

        return image, mask

    def plot(self, arg, suptitle: str | None = None) -> None:
        pass

    def plot_sample(self, sample, prediction=None, suptitle: str | None = None, class_names=None):
        pass
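OpenEarthMapNonGeo reads the list of samples for a split from {split}.txt under data_root. Each line is a file name. _parse_file_name treats everything before the last underscore as the region folder, which holds images/ and labels/ subfolders with identically named files.

The sketch below reproduces that path logic with pathlib. The file names are hypothetical. Unlike the dataset code, the sketch skips blank lines, which would otherwise turn into bogus paths.

```python
from collections.abc import Iterable
from pathlib import PurePosixPath


def parse_file_name(data_root: str, file_name: str) -> tuple[PurePosixPath, PurePosixPath]:
    """Map 'region_k.tif' to <root>/region/images/... and <root>/region/labels/..."""
    region = file_name[: file_name.rfind("_")]  # text before the last underscore
    region_path = PurePosixPath(data_root, region)
    return region_path / "images" / file_name, region_path / "labels" / file_name


def read_split(data_root: str, lines: Iterable[str]) -> list[tuple[PurePosixPath, PurePosixPath]]:
    """Parse every non-blank line of a split file into (image, label) paths."""
    return [parse_file_name(data_root, line.strip()) for line in lines if line.strip()]


for image_path, label_path in read_split("/data/openearthmap", ["region_a_1.tif\n", "\n"]):
    print(image_path, label_path)
```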
__init__(data_root, bands=BAND_SETS['all'], transform=None, split='train', crop_size=256, random_crop=True)

Initialize a new instance of the OpenEarthMapNonGeo dataset.

Parameters:
  • data_root (str) –

    The root directory containing the dataset files.

  • bands (Sequence[str], default: BAND_SETS['all'] ) –

    A list of band names to be used. Default is BAND_SETS["all"].

  • transform (Compose or None, default: None ) –

    A transformation pipeline to be applied to the data. If None, a default transform converting the data to a tensor is applied.

  • split (str, default: 'train' ) –

    The dataset split to use ("train", "test", or "val"). Default is "train".

  • crop_size (int, default: 256 ) –

    The size (in pixels) of the crop to apply to images. Must be greater than 0. Default is 256.

  • random_crop (bool, default: True ) –

    If True, performs a random crop; otherwise, performs a center crop. Default is True.
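The crop_size and random_crop options drive _random_crop and _center_crop, which cut the same window out of the (C, H, W) image and its mask. A standalone sketch of the same logic follows; crop_pair is a hypothetical name.

NumPy's integer samplers treat the upper bound as exclusive. The sketch therefore adds 1 so that every valid offset can be drawn, including the case where the image is exactly crop_size pixels on a side. It also uses an explicit np.random.Generator for reproducibility.

```python
import numpy as np


def crop_pair(image, mask, crop_size, random_crop, rng):
    """Cut the same crop_size x crop_size window out of a (C, H, W) image and its mask."""
    _, height, width = image.shape
    if height < crop_size or width < crop_size:
        raise ValueError("image is smaller than crop_size")
    if random_crop:
        # integers() excludes the upper bound, hence the + 1.
        top = int(rng.integers(0, height - crop_size + 1))
        left = int(rng.integers(0, width - crop_size + 1))
    else:
        top = (height - crop_size) // 2
        left = (width - crop_size) // 2
    rows = slice(top, top + crop_size)
    cols = slice(left, left + crop_size)
    return image[:, rows, cols], mask[:, rows, cols]


rng = np.random.default_rng(0)
image = np.arange(3 * 5 * 7).reshape(3, 5, 7)
mask = np.arange(5 * 7).reshape(1, 5, 7)
cropped_image, cropped_mask = crop_pair(image, mask, crop_size=4, random_crop=True, rng=rng)
print(cropped_image.shape, cropped_mask.shape)  # (3, 4, 4) (1, 4, 4)
```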

Raises:
  • Exception

    If the provided split is not one of "train", "test", or "val".

  • AssertionError

    If crop_size is not greater than 0.


terratorch.datasets.pastis

PASTIS

Bases: NonGeoDataset

PyTorch Dataset class to load samples from the PASTIS dataset, for semantic and panoptic segmentation.
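PASTIS expresses every observation date as a number of days relative to reference_date. For each patch and satellite, it keeps a binary table over the day offsets in date_interval, and get_dates returns the offsets whose entry is 1.

The sketch below mirrors that encoding using the standard library instead of pandas. It assumes that each patch's dates arrive as a dict of YYYYMMDD integers, as the from_dict(..., orient="index") call in the source suggests. The helper names are hypothetical, and this sketch simply drops dates outside the interval.

```python
from datetime import datetime, timezone


def days_since(reference_date: str, yyyymmdd: int) -> int:
    """Days between an observation (YYYYMMDD) and the reference date (YYYY-MM-DD)."""
    reference = datetime(*map(int, reference_date.split("-")), tzinfo=timezone.utc)
    text = str(yyyymmdd)
    observed = datetime(int(text[:4]), int(text[4:6]), int(text[6:]), tzinfo=timezone.utc)
    return (observed - reference).days


def observation_offsets(date_seq: dict, reference_date: str = "2018-09-01",
                        date_interval: tuple[int, int] = (-200, 600)) -> list[int]:
    """Build the binary date table for one patch and read the offsets back."""
    begin, end = date_interval
    table = [0] * (end - begin)  # one slot per day offset in [begin, end)
    for value in date_seq.values():
        offset = days_since(reference_date, value)
        if begin <= offset < end:  # this sketch drops dates outside the interval
            table[offset - begin] = 1
    return [begin + i for i, flag in enumerate(table) if flag]


print(observation_offsets({"0": 20180911, "1": 20180822, "2": 20190901}))  # [-10, 10, 365]
```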

Source code in terratorch/datasets/pastis.py
class PASTIS(NonGeoDataset):
    """
        PyTorch Dataset class to load samples from the [PASTIS](https://github.com/VSainteuf/pastis-benchmark) dataset,
        for semantic and panoptic segmentation.
    """
    def __init__(
        self,
        data_root,
        norm=True,  # noqa: FBT002
        target="semantic",
        folds=None,
        reference_date="2018-09-01",
        date_interval = (-200,600),
        class_mapping=None,
        transform = None,
        truncate_image = None,
        pad_image = None,
        satellites=["S2"],  # noqa: B006
    ):
        """

        Args:
            data_root (str): Path to the dataset.
            norm (bool): If true, images are standardised using pre-computed
                channel-wise means and standard deviations.
            reference_date (str, Format : 'YYYY-MM-DD'): Defines the reference date
                based on which all observation dates are expressed. Along with the image
                time series and the target tensor, this dataloader yields the sequence
                of observation dates (in terms of number of days since the reference
                date). This sequence of dates is used for instance for the positional
                encoding in attention based approaches.
            target (str): 'semantic' or 'instance'. Defines which type of target is
                returned by the dataloader.
                * If 'semantic' the target tensor is a tensor containing the class of
                each pixel.
                * If 'instance' the target tensor is the concatenation of several
                signals, necessary to train the Parcel-as-Points module:
                    - the centerness heatmap,
                    - the instance ids,
                    - the Voronoi partitioning of the patch with regard to the parcels'
                    centers,
                    - the (height, width) size of each parcel,
                    - the semantic label of each parcel,
                    - the semantic label of each pixel.
            folds (list, optional): List of ints specifying which of the 5 official
                folds to load. By default (when None is specified), all folds are loaded.
            class_mapping (dict, optional): A dictionary to define a mapping between the
                default 18 class nomenclature and another class grouping. If not provided, 
                the default class mapping is used.
            transform (callable, optional): A transform to apply to the loaded data 
                (images, dates, and masks). By default, no transformation is applied.
            truncate_image (int, optional): Truncate the time dimension of the image to 
                a specified number of timesteps. If None, no truncation is performed.
            pad_image (int, optional): Pad the time dimension of the image to a specified 
                number of timesteps. If None, no padding is applied.
            satellites (list): Defines the satellites to use. If you are using PASTIS-R, you
                have access to Sentinel-2 imagery and Sentinel-1 observations in Ascending
                and Descending orbits, respectively S2, S1A, and S1D. For example, use
                satellites=['S2', 'S1A'] for Sentinel-2 + Sentinel-1 ascending time series,
                or satellites=['S2', 'S1A', 'S1D'] to retrieve all time series. If you are using
                PASTIS, only S2 observations are available.
        """
        if target not in ["semantic", "instance"]:
            msg = f"Target '{target}' not recognized. Use 'semantic', or 'instance'."
            raise ValueError(msg)
        valid_satellites = {"S2", "S1A", "S1D"}
        for sat in satellites:
            if sat not in valid_satellites:
                msg = f"Satellite '{sat}' not recognized. Valid options are {valid_satellites}."
                raise ValueError(msg)

        super().__init__()
        self.data_root = data_root
        self.norm = norm
        self.reference_date = datetime(*map(int, reference_date.split("-")), tzinfo=timezone.utc)
        self.class_mapping = (
            np.vectorize(lambda x: class_mapping[x])
            if class_mapping is not None
            else class_mapping
        )
        self.target = target
        self.satellites = satellites
        self.transform = transform
        self.truncate_image = truncate_image
        self.pad_image = pad_image
        # loads patches metadata
        self.meta_patch = gpd.read_file(os.path.join(data_root, "metadata.geojson"))
        self.meta_patch.index = self.meta_patch["ID_PATCH"].astype(int)
        self.meta_patch.sort_index(inplace=True)
        # stores a date table for each satellite
        self.date_tables = {s: None for s in satellites}
        # date interval used in the PASTIS benchmark paper.
        date_interval_begin, date_interval_end = date_interval
        self.date_range = np.array(range(date_interval_begin, date_interval_end))
        for s in satellites:
            # maps patches to its observation dates
            dates = self.meta_patch[f"dates-{s}"]
            date_table = pd.DataFrame(
                index=self.meta_patch.index, columns=self.date_range, dtype=int
            )
            for pid, date_seq in dates.items():
                if type(date_seq) is str:
                    date_seq = json.loads(date_seq)  # noqa: PLW2901
                # convert each YYYYMMDD date to days since the reference date
                d = pd.DataFrame().from_dict(date_seq, orient="index")
                d = d[0].apply(
                    lambda x: (
                        datetime(int(str(x)[:4]), int(str(x)[4:6]), int(str(x)[6:]), tzinfo=timezone.utc)
                        - self.reference_date
                    ).days
                )
                date_table.loc[pid, d.values] = 1
            date_table = date_table.fillna(0)
            self.date_tables[s] = {
                index: np.array(list(d.values()))
                for index, d in date_table.to_dict(orient="index").items()
            }

        # selects patches corresponding to selected folds
        if folds is not None:
            self.meta_patch = pd.concat(
                [self.meta_patch[self.meta_patch["Fold"] == f] for f in folds]
            )

        self.len = self.meta_patch.shape[0]
        self.id_patches = self.meta_patch.index

        # loads normalization values
        if norm:
            self.norm = {}
            for s in self.satellites:
                with open(
                    os.path.join(data_root, f"NORM_{s}_patch.json")
                ) as file:
                    normvals = json.loads(file.read())
                selected_folds = folds if folds is not None else range(1, 6)
                means = [normvals[f"Fold_{f}"]["mean"] for f in selected_folds]
                stds = [normvals[f"Fold_{f}"]["std"] for f in selected_folds]
                self.norm[s] = np.stack(means).mean(axis=0), np.stack(stds).mean(axis=0)
                self.norm[s] = (
                    self.norm[s][0],
                    self.norm[s][1],
                )
        else:
            self.norm = None

    def __len__(self):
        return self.len

    def get_dates(self, id_patch, sat):
        return self.date_range[np.where(self.date_tables[sat][id_patch] == 1)[0]]

    def __getitem__(self, item):
        id_patch = self.id_patches[item]
        output = {}
        satellites = {}
        for satellite in self.satellites:
            data = np.load(
                os.path.join(
                    self.data_root,
                    f"DATA_{satellite}",
                    f"{satellite}_{id_patch}.npy",
                )
            ).astype(np.float32)

            if self.norm is not None:
                data = data - self.norm[satellite][0][None, :, None, None]
                data = data / self.norm[satellite][1][None, :, None, None]

            if self.truncate_image and data.shape[0] > self.truncate_image:
                data = data[-self.truncate_image:]

            if self.pad_image and data.shape[0] < self.pad_image:
                data = pad_numpy(data, self.pad_image)

            satellites[satellite] = data.astype(np.float32)


        if self.target == "semantic":
            target = np.load(
                os.path.join(self.data_root, "ANNOTATIONS", f"TARGET_{id_patch}.npy")
            )
            target = target[0].astype(int)
            if self.class_mapping is not None:
                target = self.class_mapping(target)
        elif self.target == "instance":
            heatmap = np.load(os.path.join(self.data_root, "INSTANCE_ANNOTATIONS", f"HEATMAP_{id_patch}.npy"))
            instance_ids = np.load(os.path.join(self.data_root, "INSTANCE_ANNOTATIONS", f"INSTANCES_{id_patch}.npy"))
            zones_path = os.path.join(self.data_root, "INSTANCE_ANNOTATIONS", f"ZONES_{id_patch}.npy")
            pixel_to_object_mapping = np.load(zones_path)
            pixel_semantic_annotation = np.load(os.path.join(self.data_root, "ANNOTATIONS", f"TARGET_{id_patch}.npy"))

            if self.class_mapping is not None:
                pixel_semantic_annotation = self.class_mapping(pixel_semantic_annotation[0])
            else:
                pixel_semantic_annotation = pixel_semantic_annotation[0]

            size = np.zeros((*instance_ids.shape, 2))
            object_semantic_annotation = np.zeros(instance_ids.shape)
            for instance_id in np.unique(instance_ids):
                if instance_id != 0:
                    h = (instance_ids == instance_id).any(axis=-1).sum()
                    w = (instance_ids == instance_id).any(axis=-2).sum()
                    size[pixel_to_object_mapping == instance_id] = (h, w)
                    semantic_value = pixel_semantic_annotation[instance_ids == instance_id][0]
                    object_semantic_annotation[pixel_to_object_mapping == instance_id] = semantic_value

            target = np.concatenate(
                [
                    heatmap[:, :, None],
                    instance_ids[:, :, None],
                    pixel_to_object_mapping[:, :, None],
                    size,
                    object_semantic_annotation[:, :, None],
                    pixel_semantic_annotation[:, :, None],
                ], axis=-1).astype(np.float32)

        dates = {}
        for satellite in self.satellites:
            date = np.array(self.get_dates(id_patch, satellite))

            if self.truncate_image and len(date) > self.truncate_image:
                date = date[-self.truncate_image:]

            if self.pad_image and len(date) < self.pad_image:
                date = pad_dates_numpy(date, self.pad_image)

            dates[satellite] = torch.from_numpy(date)

        output["image"] = satellites["S2"].transpose(0, 2, 3, 1)
        output["mask"] = target

        if self.transform:
            output = self.transform(**output)

        output.update(satellites)
        output["dates"] = dates

        return output


    def plot(self, sample, suptitle=None):
        dates = sample["dates"]
        target = sample.get("mask")  # __getitem__ stores the target under "mask"

        if "S2" not in sample:
            warnings.warn("No RGB image.", stacklevel=2)
            return None

        image_data = sample["S2"]
        date_data = dates["S2"]

        rgb_images = []
        for i in range(image_data.shape[0]):
            rgb_image = image_data[i, :3, :, :].numpy().transpose(1, 2, 0)

            rgb_min = rgb_image.min(axis=(0, 1), keepdims=True)
            rgb_max = rgb_image.max(axis=(0, 1), keepdims=True)
            denom = rgb_max - rgb_min
            denom[denom == 0] = 1
            rgb_image = (rgb_image - rgb_min) / denom

            rgb_images.append(np.clip(rgb_image, 0, 1))

        return self._plot_sample(rgb_images, date_data, target, suptitle=suptitle)

    def _plot_sample(
        self,
        images: list[np.ndarray],
        dates: torch.Tensor,
        target: torch.Tensor | None,
        suptitle: str | None = None
    ):
        num_images = len(images)
        cols = 5
        rows = (num_images + cols) // cols

        # squeeze=False keeps ax 2D even when rows == 1, so ax[r, c] indexing always works
        fig, ax = plt.subplots(rows, cols, figsize=(20, 4 * rows), squeeze=False)

        for i, image in enumerate(images):
            ax[i // cols, i % cols].imshow(image)
            ax[i // cols, i % cols].set_title(f"Image {i + 1} - Day {dates[i].item()}")
            ax[i // cols, i % cols].axis("off")

        if target is not None:
            if rows * cols > num_images:
                target_ax = ax[(num_images) // cols, (num_images) % cols]
            else:
                fig.add_subplot(rows + 1, 1, 1)
                target_ax = fig.gca()

            target_ax.imshow(target.numpy(), cmap="tab20")
            target_ax.set_title("Target")
            target_ax.axis("off")

        for k in range(num_images + 1, rows * cols):
            ax[k // cols, k % cols].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        return fig
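For target='instance', the (height, width) channel of the target is computed per instance by collapsing its boolean mask along each axis: any(axis=-1) flags the rows the instance touches and any(axis=-2) flags the columns, so their sums count the occupied rows and columns (the bounding-box size when the instance has no gaps). A standalone NumPy sketch of that step on a toy id map follows; instance_extents is a hypothetical name, not part of TerraTorch.

```python
from __future__ import annotations

import numpy as np


def instance_extents(instance_ids: np.ndarray) -> dict[int, tuple[int, int]]:
    """Return {instance_id: (h, w)} for every non-zero id in a 2D id map.

    h counts the rows the instance occupies and w the columns, mirroring the
    any()-projection used in the PASTIS dataset. Id 0 is treated as background.
    """
    extents = {}
    for instance_id in np.unique(instance_ids):
        if instance_id == 0:
            continue
        mask = instance_ids == instance_id
        h = int(mask.any(axis=-1).sum())  # rows that contain the instance
        w = int(mask.any(axis=-2).sum())  # columns that contain the instance
        extents[int(instance_id)] = (h, w)
    return extents


ids = np.array([
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [2, 0, 0, 0],
])
print(instance_extents(ids))  # {1: (2, 2), 2: (1, 1)}
```

In the dataset the resulting pair is then written into every pixel of the instance's Voronoi cell (pixel_to_object_mapping), rather than returned as a dictionary.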
__init__(data_root, norm=True, target='semantic', folds=None, reference_date='2018-09-01', date_interval=(-200, 600), class_mapping=None, transform=None, truncate_image=None, pad_image=None, satellites=['S2'])
Parameters:
  • data_root (str) –

    Path to the dataset.

  • norm (bool, default: True ) –

    If true, images are standardised using pre-computed channel-wise means and standard deviations.

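  • reference_date (str, default: '2018-09-01' ) –

    Reference date in 'YYYY-MM-DD' format. All observation dates are expressed relative to it: along with the image time series and the target tensor, this dataloader yields the sequence of observation dates as number of days since the reference date. This sequence of dates is used, for instance, for the positional encoding in attention-based approaches.

  • date_interval (tuple, default: (-200, 600) ) –

    Range of day offsets, relative to reference_date, covered by the per-patch date tables (start inclusive, end exclusive). The default is the interval used in the PASTIS benchmark paper.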

  • target (str, default: 'semantic' ) –

    'semantic' or 'instance'. Defines which type of target is returned by the dataloader.
      * If 'semantic', the target tensor contains the class of each pixel.
      * If 'instance', the target tensor is the concatenation of several signals needed to train the Parcel-as-Points module:
          - the centerness heatmap,
          - the instance ids,
          - the Voronoi partitioning of the patch with regard to the parcels' centers,
          - the (height, width) size of each parcel,
          - the semantic label of each parcel,
          - the semantic label of each pixel.

  • folds (list, default: None ) –

    List of ints specifying which of the 5 official folds to load. By default (when None is specified), all folds are loaded.

  • class_mapping (dict, default: None ) –

    A dictionary to define a mapping between the default 18 class nomenclature and another class grouping. If not provided, the default class mapping is used.

  • transform (callable, default: None ) –

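    A transform applied to the image and mask of each sample. The observation dates and the per-satellite time series are added to the output after the transform runs, so they are not transformed. By default, no transformation is applied.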

  • truncate_image (int, default: None ) –

    Truncate the time dimension of the image to a specified number of timesteps. If None, no truncation is performed.

  • pad_image (int, default: None ) –

    Pad the time dimension of the image to a specified number of timesteps. If None, no padding is applied.

  • satellites (list, default: ['S2'] ) –

    Defines the satellites to use. If you are using PASTIS-R, you have access to Sentinel-2 imagery and Sentinel-1 observations in Ascending and Descending orbits, respectively S2, S1A, and S1D. For example, use satellites=['S2', 'S1A'] for Sentinel-2 + Sentinel-1 ascending time series, or satellites=['S2', 'S1A', 'S1D'] to retrieve all time series. If you are using PASTIS, only S2 observations are available.
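truncate_image and pad_image fix the length of the time dimension. For the observation dates, the code shown here keeps the last truncate_image entries when a series is too long. A minimal sketch of that behavior on the leading axis follows; truncate_and_pad is a hypothetical helper, and zero-padding at the end of the series is an assumption of this sketch, since the padding helpers TerraTorch uses are not shown on this page.

```python
from __future__ import annotations

import numpy as np


def truncate_and_pad(series: np.ndarray, truncate: int | None = None, pad: int | None = None) -> np.ndarray:
    """Fix the length of the leading (time) axis of ``series``.

    Truncation keeps the last ``truncate`` timesteps. Zero-padding at the end is
    an assumption of this sketch and may differ from the library's helpers.
    """
    if truncate and len(series) > truncate:
        series = series[-truncate:]
    if pad and len(series) < pad:
        missing = pad - len(series)
        # Pad only the time axis; band and spatial axes are left untouched.
        pad_width = [(0, missing)] + [(0, 0)] * (series.ndim - 1)
        series = np.pad(series, pad_width, mode="constant", constant_values=0)
    return series


days = np.array([3, 17, 40, 65, 90])
print(truncate_and_pad(days, truncate=3))  # [40 65 90]
print(truncate_and_pad(days, pad=7))       # [ 3 17 40 65 90  0  0]
```

Setting both parameters to the same value gives every sample the same number of timesteps, which is what allows default batch collation of the time series.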

Source code in terratorch/datasets/pastis.py
def __init__(
    self,
    data_root,
    norm=True,  # noqa: FBT002
    target="semantic",
    folds=None,
    reference_date="2018-09-01",
    date_interval = (-200,600),
    class_mapping=None,
    transform = None,
    truncate_image = None,
    pad_image = None,
    satellites=["S2"],  # noqa: B006
):
    """

    Args:
        data_root (str): Path to the dataset.
        norm (bool): If true, images are standardised using pre-computed
            channel-wise means and standard deviations.
        reference_date (str, Format : 'YYYY-MM-DD'): Defines the reference date
            based on which all observation dates are expressed. Along with the image
            time series and the target tensor, this dataloader yields the sequence
            of observation dates (in terms of number of days since the reference
            date). This sequence of dates is used for instance for the positional
            encoding in attention based approaches.
        target (str): 'semantic' or 'instance'. Defines which type of target is
            returned by the dataloader.
            * If 'semantic' the target tensor is a tensor containing the class of
            each pixel.
            * If 'instance' the target tensor is the concatenation of several
            signals, necessary to train the Parcel-as-Points module:
                - the centerness heatmap,
                - the instance ids,
                - the voronoi partitioning of the patch with regards to the parcels'
                centers,
                - the (height, width) size of each parcel,
                - the semantic label of each parcel,
                - the semantic label of each pixel.
        folds (list, optional): List of ints specifying which of the 5 official
            folds to load. By default (when None is specified), all folds are loaded.
        class_mapping (dict, optional): A dictionary to define a mapping between the
            default 18 class nomenclature and another class grouping. If not provided, 
            the default class mapping is used.
        transform (callable, optional): A transform applied to the image and mask.
            Dates and per-satellite time series are added to the output afterwards
            and are not transformed. By default, no transformation is applied.
        truncate_image (int, optional): Truncate the time dimension of the image to 
            a specified number of timesteps. If None, no truncation is performed.
        pad_image (int, optional): Pad the time dimension of the image to a specified 
            number of timesteps. If None, no padding is applied.
        satellites (list): Defines the satellites to use. If you are using PASTIS-R, you
            have access to Sentinel-2 imagery and Sentinel-1 observations in Ascending
            and Descending orbits, respectively S2, S1A, and S1D. For example, use
            satellites=['S2', 'S1A'] for Sentinel-2 + Sentinel-1 ascending time series,
            or satellites=['S2', 'S1A', 'S1D'] to retrieve all time series. If you are using
            PASTIS, only S2 observations are available.
    """
    if target not in ["semantic", "instance"]:
        msg = f"Target '{target}' not recognized. Use 'semantic', or 'instance'."
        raise ValueError(msg)
    valid_satellites = {"S2", "S1A", "S1D"}
    for sat in satellites:
        if sat not in valid_satellites:
            msg = f"Satellite '{sat}' not recognized. Valid options are {valid_satellites}."
            raise ValueError(msg)

    super().__init__()
    self.data_root = data_root
    self.norm = norm
    self.reference_date = datetime(*map(int, reference_date.split("-")), tzinfo=timezone.utc)
    self.class_mapping = (
        np.vectorize(lambda x: class_mapping[x])
        if class_mapping is not None
        else class_mapping
    )
    self.target = target
    self.satellites = satellites
    self.transform = transform
    self.truncate_image = truncate_image
    self.pad_image = pad_image
    # loads patches metadata
    self.meta_patch = gpd.read_file(os.path.join(data_root, "metadata.geojson"))
    self.meta_patch.index = self.meta_patch["ID_PATCH"].astype(int)
    self.meta_patch.sort_index(inplace=True)
    # stores a date table for each satellite
    self.date_tables = {s: None for s in satellites}
    # date interval used in the PASTIS benchmark paper.
    date_interval_begin, date_interval_end = date_interval
    self.date_range = np.array(range(date_interval_begin, date_interval_end))
    for s in satellites:
        # maps patches to their observation dates
        dates = self.meta_patch[f"dates-{s}"]
        date_table = pd.DataFrame(
            index=self.meta_patch.index, columns=self.date_range, dtype=int
        )
        for pid, date_seq in dates.items():
            if type(date_seq) is str:
                date_seq = json.loads(date_seq)  # noqa: PLW2901
            # convert dates to number of days since the reference date
            d = pd.DataFrame().from_dict(date_seq, orient="index")
            d = d[0].apply(
                lambda x: (
                    datetime(int(str(x)[:4]), int(str(x)[4:6]), int(str(x)[6:]), tzinfo=timezone.utc)
                    - self.reference_date
                ).days
            )
            date_table.loc[pid, d.values] = 1
        date_table = date_table.fillna(0)
        self.date_tables[s] = {
            index: np.array(list(d.values()))
            for index, d in date_table.to_dict(orient="index").items()
        }

    # selects patches corresponding to the selected folds
    if folds is not None:
        self.meta_patch = pd.concat(
            [self.meta_patch[self.meta_patch["Fold"] == f] for f in folds]
        )

    self.len = self.meta_patch.shape[0]
    self.id_patches = self.meta_patch.index

    # loads normalization values
    if norm:
        self.norm = {}
        for s in self.satellites:
            with open(
                os.path.join(data_root, f"NORM_{s}_patch.json")
            ) as file:
                normvals = json.loads(file.read())
            selected_folds = folds if folds is not None else range(1, 6)
            means = [normvals[f"Fold_{f}"]["mean"] for f in selected_folds]
            stds = [normvals[f"Fold_{f}"]["std"] for f in selected_folds]
            # average the per-fold statistics into one (mean, std) pair per satellite
            self.norm[s] = np.stack(means).mean(axis=0), np.stack(stds).mean(axis=0)
    else:
        self.norm = None
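The date handling in __init__ turns each YYYYMMDD observation date into a signed day offset from reference_date and records, per patch, which offsets inside date_interval have an observation. The sketch below reproduces those two steps with the standard library only; the helper names are hypothetical, and unlike the DataFrame-based table in the listing, offsets outside the interval are simply ignored.

```python
from __future__ import annotations

from datetime import datetime, timezone


def days_since_reference(yyyymmdd: int | str, reference: datetime) -> int:
    """Convert a YYYYMMDD date (int or str) into a day offset from ``reference``."""
    s = str(yyyymmdd)
    date = datetime(int(s[:4]), int(s[4:6]), int(s[6:]), tzinfo=timezone.utc)
    return (date - reference).days


def presence_vector(offsets: list[int], begin: int = -200, end: int = 600) -> list[int]:
    """1 for each day in range(begin, end) that has an observation, else 0."""
    observed = set(offsets)
    return [1 if day in observed else 0 for day in range(begin, end)]


reference = datetime(2018, 9, 1, tzinfo=timezone.utc)
offsets = [days_since_reference(d, reference) for d in (20180801, 20180911, "20181005")]
print(offsets)                        # [-31, 10, 34]
print(sum(presence_vector(offsets)))  # 3
```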

terratorch.datasets.sen1floods11

Sen1Floods11NonGeo

Bases: NonGeoDataset

NonGeo dataset implementation for sen1floods11.
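Split membership in this dataset is decided by matching file names against the entries of a split list under v1.1/splits/flood_handlabeled, with extensions ignored and substring matches allowed. The sketch below illustrates that idea; filter_by_split is a hypothetical stand-in for TerraTorch's filter_valid_files, whose exact rules may differ, and the file names are illustrative only.

```python
from __future__ import annotations

import os


def filter_by_split(paths: list[str], split_lines: list[str]) -> list[str]:
    """Keep paths whose extension-less file name contains any split-list entry."""
    # Blank lines are dropped: an empty entry would match every name as a substring.
    valid = {line.strip() for line in split_lines if line.strip()}
    kept = []
    for path in paths:
        stem = os.path.splitext(os.path.basename(path))[0]
        if any(entry in stem for entry in valid):
            kept.append(path)
    return kept


files = ["S2Hand/Bolivia_103757_S2Hand.tif", "S2Hand/Ghana_5079_S2Hand.tif"]
print(filter_by_split(files, ["Bolivia_103757\n"]))  # ['S2Hand/Bolivia_103757_S2Hand.tif']
```

Because images and labels are filtered with the same entries and both lists are sorted, the i-th image and the i-th label are expected to belong to the same chip.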

Source code in terratorch/datasets/sen1floods11.py
class Sen1Floods11NonGeo(NonGeoDataset):
    """NonGeo dataset implementation for [sen1floods11](https://github.com/cloudtostreet/Sen1Floods11)."""

    all_band_names = (
            "COASTAL_AEROSOL",
            "BLUE",
            "GREEN",
            "RED",
            "RED_EDGE_1",
            "RED_EDGE_2",
            "RED_EDGE_3",
            "NIR_BROAD",
            "NIR_NARROW",
            "WATER_VAPOR",
            "CIRRUS",
            "SWIR_1",
            "SWIR_2",
    )
    rgb_bands = ("RED", "GREEN", "BLUE")
    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}
    num_classes = 2
    splits = {"train": "train", "val": "valid", "test": "test"}
    data_dir = "v1.1/data/flood_events/HandLabeled/S2Hand"
    label_dir = "v1.1/data/flood_events/HandLabeled/LabelHand"
    split_dir = "v1.1/splits/flood_handlabeled"
    metadata_file = "v1.1/Sen1Floods11_Metadata.geojson"

    def __init__(
        self,
        data_root: str,
        split: str = "train",
        bands: Sequence[str] = BAND_SETS["all"],
        transform: A.Compose | None = None,
        constant_scale: float = 0.0001,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,  # noqa: FBT001, FBT002
    ) -> None:
        """Constructor

        Args:
            data_root (str): Path to the data root directory.
            split (str): One of 'train', 'val' or 'test'.
            bands (list[str]): Bands that should be output by the dataset. Defaults to all bands.
            transform (A.Compose | None): Albumentations transform to be applied.
                Should end with ToTensorV2(). Defaults to None, which applies ToTensorV2().
            constant_scale (float): Factor to multiply image values by. Defaults to 0.0001.
            no_data_replace (float | None): Replace nan values in input images with this value.
                If None, does no replacement. Defaults to 0.
            no_label_replace (int | None): Replace nan values in label with this value.
                If None, does no replacement. Defaults to -1.
            use_metadata (bool): Whether to return metadata info (time and location).
        """
        super().__init__()
        if split not in self.splits:
            msg = f"Incorrect split '{split}', please choose one of {self.splits}."
            raise ValueError(msg)
        split_name = self.splits[split]
        self.split = split

        validate_bands(bands, self.all_band_names)
        self.bands = bands
        self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands])
        self.constant_scale = constant_scale
        self.data_root = Path(data_root)

        data_dir = self.data_root / self.data_dir
        label_dir = self.data_root / self.label_dir

        self.image_files = sorted(glob.glob(os.path.join(data_dir, "*_S2Hand.tif")))
        self.segmentation_mask_files = sorted(glob.glob(os.path.join(label_dir, "*_LabelHand.tif")))

        split_file = self.data_root / self.split_dir / f"flood_{split_name}_data.txt"
        with open(split_file) as f:
            split_lines = f.readlines()
        valid_files = {line.strip() for line in split_lines}
        self.image_files = filter_valid_files(
            self.image_files,
            valid_files=valid_files,
            ignore_extensions=True,
            allow_substring=True,
        )
        self.segmentation_mask_files = filter_valid_files(
            self.segmentation_mask_files,
            valid_files=valid_files,
            ignore_extensions=True,
            allow_substring=True,
        )

        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.use_metadata = use_metadata
        self.metadata = None
        if self.use_metadata:
            self.metadata = geopandas.read_file(self.data_root / self.metadata_file)

        # If no transform is given, only convert the sample to torch tensors
        self.transform = transform if transform else default_transform

    def __len__(self) -> int:
        return len(self.image_files)

    def _get_date(self, index: int) -> torch.Tensor:
        file_name = self.image_files[index]
        location = os.path.basename(file_name).split("_")[0]
        if self.metadata[self.metadata["location"] == location].shape[0] != 1:
            date = pd.to_datetime("13-10-1998", dayfirst=True)
        else:
            date = pd.to_datetime(self.metadata[self.metadata["location"] == location]["s2_date"].item())

        return torch.tensor([[date.year, date.dayofyear - 1]], dtype=torch.float32)  # (n_timesteps, coords)

    def _get_coords(self, image: DataArray) -> torch.Tensor:

        center_lat = image.y[image.y.shape[0] // 2]
        center_lon = image.x[image.x.shape[0] // 2]
        lat_lon = np.asarray([center_lat, center_lon])

        return torch.tensor(lat_lon, dtype=torch.float32)

    def __getitem__(self, index: int) -> dict[str, Any]:
        image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace)

        location_coords, temporal_coords = None, None
        if self.use_metadata:
            location_coords = self._get_coords(image)
            temporal_coords = self._get_date(index)

        # to channels last
        image = image.to_numpy()
        image = np.moveaxis(image, 0, -1)

        # filter bands
        image = image[..., self.band_indices]

        output = {
            "image": image.astype(np.float32) * self.constant_scale,
            "mask": self._load_file(
                self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[0],
        }
        if self.transform:
            output = self.transform(**output)
        output["mask"] = output["mask"].long()

        if self.use_metadata:
            output["location_coords"] = location_coords
            output["temporal_coords"] = temporal_coords

        return output

    def _load_file(self, path: Path, nan_replace: int | float | None = None) -> DataArray:
        data = rioxarray.open_rasterio(path, masked=True)
        if nan_replace is not None:
            data = data.fillna(nan_replace)
        return data

    def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        num_images = 4

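        # check membership first: bands.index() would raise a bare ValueError on a missing band
        missing = [band for band in self.rgb_bands if band not in self.bands]
        if missing:
            msg = f"Dataset doesn't contain the RGB bands {missing}"
            raise ValueError(msg)
        rgb_indices = [self.bands.index(band) for band in self.rgb_bands]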

        # RGB -> channels-last
        image = sample["image"][rgb_indices, ...].permute(1, 2, 0).numpy()
        mask = sample["mask"].numpy()

        image = clip_image(image)

        if "prediction" in sample:
            prediction = sample["prediction"]
            num_images += 1
        else:
            prediction = None

        fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed")

        ax[0].axis("off")

        norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1)
        ax[1].axis("off")
        ax[1].title.set_text("Image")
        ax[1].imshow(image)

        ax[2].axis("off")
        ax[2].title.set_text("Ground Truth Mask")
        ax[2].imshow(mask, cmap="jet", norm=norm)

        ax[3].axis("off")
        ax[3].title.set_text("GT Mask on Image")
        ax[3].imshow(image)
        ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)

        if "prediction" in sample:
            ax[4].title.set_text("Predicted Mask")
            ax[4].imshow(prediction, cmap="jet", norm=norm)

        cmap = plt.get_cmap("jet")
        legend_data = [[i, cmap(norm(i)), str(i)] for i in range(self.num_classes)]
        handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data]
        labels = [n for k, c, n in legend_data]
        ax[0].legend(handles, labels, loc="center")
        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
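In __getitem__, the image goes through a short fixed pipeline: no-data (NaN) pixels are filled, the array is moved from channels-first to channels-last, the requested bands are picked by name, and the values are multiplied by constant_scale. The NumPy-only sketch below mirrors that order on a synthetic array; preprocess_image is a hypothetical helper, and np.where stands in for the rioxarray fillna call the dataset uses.

```python
from __future__ import annotations

import numpy as np

# Band order copied from Sen1Floods11NonGeo.all_band_names.
ALL_BAND_NAMES = (
    "COASTAL_AEROSOL", "BLUE", "GREEN", "RED", "RED_EDGE_1", "RED_EDGE_2",
    "RED_EDGE_3", "NIR_BROAD", "NIR_NARROW", "WATER_VAPOR", "CIRRUS", "SWIR_1", "SWIR_2",
)


def preprocess_image(
    raw: np.ndarray,
    bands: tuple[str, ...],
    constant_scale: float = 0.0001,
    no_data_replace: float | None = 0,
) -> np.ndarray:
    """Turn a (C, H, W) array into a scaled (H, W, len(bands)) float32 array."""
    if no_data_replace is not None:
        # Stand-in for rioxarray's fillna on a masked raster.
        raw = np.where(np.isnan(raw), no_data_replace, raw)
    image = np.moveaxis(raw, 0, -1)  # channels first -> channels last
    band_indices = np.asarray([ALL_BAND_NAMES.index(b) for b in bands])
    image = image[..., band_indices]  # keep the requested bands, in the requested order
    return image.astype(np.float32) * constant_scale


raw = np.full((13, 2, 2), 2000.0)
raw[3, 0, 0] = np.nan  # a missing RED value
out = preprocess_image(raw, ("RED", "GREEN", "BLUE"))
print(out.shape)  # (2, 2, 3)
```

Note that the bands argument also fixes the channel order of the output, which is why plot() looks RED, GREEN and BLUE up by name instead of assuming positions.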
__init__(data_root, split='train', bands=BAND_SETS['all'], transform=None, constant_scale=0.0001, no_data_replace=0, no_label_replace=-1, use_metadata=False)

Constructor

Parameters:
  • data_root (str) –

    Path to the data root directory.

  • split (str, default: 'train' ) –

    One of 'train', 'val' or 'test'.

  • bands (list[str], default: BAND_SETS['all'] ) –

    Bands that should be output by the dataset. Defaults to all bands.

  • transform (Compose | None, default: None ) –

    Albumentations transform to be applied. Should end with ToTensorV2(). Defaults to None, which applies ToTensorV2().

  • constant_scale (float, default: 0.0001 ) –

    Factor to multiply image values by. Defaults to 0.0001.

  • no_data_replace (float | None, default: 0 ) –

    Replace nan values in input images with this value. If None, does no replacement. Defaults to 0.

  • no_label_replace (int | None, default: -1 ) –

    Replace nan values in label with this value. If None, does no replacement. Defaults to -1.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info (time and location).
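With use_metadata=True, each sample gains temporal_coords, a (1, 2) float32 tensor holding [year, zero-based day of year] of the S2 acquisition, and location_coords, the y/x coordinates (named lat/lon in the code) of the central pixel. When a chip's location does not match exactly one metadata row, the code falls back to 13 October 1998. A plain-Python sketch of both computations follows; the function names are hypothetical and lists stand in for tensors.

```python
from __future__ import annotations

from datetime import date


def temporal_coords(acquisition: date) -> list[list[float]]:
    """[[year, zero-based day of year]]: one row per timestep (a single one here)."""
    day_of_year = acquisition.timetuple().tm_yday  # 1-based
    return [[float(acquisition.year), float(day_of_year - 1)]]


def location_coords(ys: list[float], xs: list[float]) -> list[float]:
    """Coordinates of the central pixel, taken from the raster's y/x vectors."""
    return [ys[len(ys) // 2], xs[len(xs) // 2]]


print(temporal_coords(date(2019, 1, 1)))  # [[2019.0, 0.0]]
print(location_coords([10.2, 10.1, 10.0], [-5.0, -4.9, -4.8, -4.7]))  # [10.1, -4.8]
```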

plot(sample, suptitle=None)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Tensor]) –

    a sample returned by __getitem__

  • suptitle (str | None, default: None ) –

    optional string to use as a suptitle

Returns:
  • Figure

    a matplotlib Figure with the rendered sample


terratorch.datasets.sen4agrinet

Sen4AgriNet

Bases: NonGeoDataset
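With scenario='random', the class shuffles the .nc patch files using the given seed and cuts the list at 60% and 80% into train, val and test. A hedged sketch of the same cut points follows; random_split is a hypothetical helper, and sorting the list first and using a private random.Random instance are choices made for this sketch so that the split depends only on the file names and the seed.

```python
from __future__ import annotations

import random


def random_split(files: list[str], seed: int = 42) -> tuple[list[str], list[str], list[str]]:
    """60/20/20 train/val/test split that depends only on the file names and the seed."""
    files = sorted(files)  # fixed starting order, independent of filesystem listing order
    random.Random(seed).shuffle(files)  # private RNG: the global random state is untouched
    train_end = int(0.6 * len(files))
    val_end = int(0.8 * len(files))
    return files[:train_end], files[train_end:val_end], files[val_end:]


patches = [f"patch_{i}.nc" for i in range(10)]
train, val, test = random_split(patches)
print(len(train), len(val), len(test))  # 6 2 2
```

The 'spatial' and 'spatio-temporal' scenarios instead split by tile region (and year), so their test sets come from a different area than the training data.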

Source code in terratorch/datasets/sen4agrinet.py
class Sen4AgriNet(NonGeoDataset):
    def __init__(
        self,
        data_root: str,
        bands: list[str] | None = None,
        scenario: str = "random",
        split: str = "train",
        transform: A.Compose | None = None,
        truncate_image: int | None = 4,
        pad_image: int | None = 4,
        spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT001, FBT002
        seed: int = 42,
    ):
        """
        PyTorch Dataset class to load samples from the [Sen4AgriNet](https://github.com/Orion-AI-Lab/S4A) dataset, supporting
        multiple scenarios for splitting the data.

        Args:
            data_root (str): Root directory of the dataset.
            bands (list of str, optional): List of band names to load. Defaults to all available bands.
            scenario (str): Defines the splitting scenario to use. Options are:
                - 'random': Random split of the data.
                - 'spatial': Split by geographical regions (Catalonia and France).
                - 'spatio-temporal': Split by region and year (France 2019 and Catalonia 2020).
            split (str): Specifies the dataset split. Options are 'train', 'val', or 'test'.
            transform (albumentations.Compose, optional): Albumentations transformations to apply to the data.
            truncate_image (int, optional): Number of timesteps to truncate the time dimension of the image.
                If None, no truncation is applied. Default is 4.
            pad_image (int, optional): Number of timesteps to pad the time dimension of the image.
                If None, no padding is applied. Default is 4.
            spatial_interpolate_and_stack_temporally (bool): Whether to interpolate bands to a common spatial size
                and concatenate them over time.
            seed (int): Random seed used for data splitting.
        """
        self.data_root = Path(data_root) / "data"
        self.transform = transform if transform else lambda **batch: to_tensor(batch)
        self.scenario = scenario
        self.seed = seed
        self.truncate_image = truncate_image
        self.pad_image = pad_image
        self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally

        if bands is None:
            bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B10", "B11", "B12", "B8A"]
        self.bands = bands

        # sorted so that the seeded split does not depend on filesystem listing order
        self.image_files = sorted(self.data_root.glob("**/*.nc"))

        self.train_files, self.val_files, self.test_files = self.split_data()

        if split == "train":
            self.image_files = self.train_files
        elif split == "val":
            self.image_files = self.val_files
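        elif split == "test":
            self.image_files = self.test_files
        else:
            msg = f"Incorrect split '{split}', please choose one of 'train', 'val' or 'test'."
            raise ValueError(msg)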

    def __len__(self):
        return len(self.image_files)

    def split_data(self):
        random.seed(self.seed)

        if self.scenario == "random":
            random.shuffle(self.image_files)
            total_files = len(self.image_files)
            train_split = int(0.6 * total_files)
            val_split = int(0.8 * total_files)

            train_files = self.image_files[:train_split]
            val_files = self.image_files[train_split:val_split]
            test_files = self.image_files[val_split:]

        elif self.scenario == "spatial":
            catalonia_files = [f for f in self.image_files if any(tile in f.stem for tile in CAT_TILES)]
            france_files = [f for f in self.image_files if any(tile in f.stem for tile in FR_TILES)]

            val_split_cat = int(0.2 * len(catalonia_files))
            train_files = catalonia_files[val_split_cat:]
            val_files = catalonia_files[:val_split_cat]
            test_files = france_files

        elif self.scenario == "spatio-temporal":
            france_files = [f for f in self.image_files if any(tile in f.stem for tile in FR_TILES)]
            catalonia_files = [f for f in self.image_files if any(tile in f.stem for tile in CAT_TILES)]

            france_2019_files = [f for f in france_files if "2019" in f.stem]
            catalonia_2020_files = [f for f in catalonia_files if "2020" in f.stem]

            val_split_france_2019 = int(0.2 * len(france_2019_files))
            train_files = france_2019_files[val_split_france_2019:]
            val_files = france_2019_files[:val_split_france_2019]
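            test_files = catalonia_2020_files

        else:
            msg = f"Unknown scenario '{self.scenario}', please choose one of 'random', 'spatial' or 'spatio-temporal'."
            raise ValueError(msg)

        return train_files, val_files, test_files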


    def __getitem__(self, index: int):
        patch_file = self.image_files[index]

        with h5py.File(patch_file, "r") as patch_data:
            output = {}
            images_over_time = []
            for band in self.bands:
                band_group = patch_data[band]
                band_data = band_group[f"{band}"][:]
                time_vector = band_group["time"][:]

                sorted_indices = np.argsort(time_vector)
                band_data = band_data[sorted_indices].astype(np.float32)

                if self.truncate_image:
                    band_data = band_data[-self.truncate_image:]
                if self.pad_image:
                    band_data = pad_numpy(band_data, self.pad_image)

                if self.spatial_interpolate_and_stack_temporally:
                    band_data = torch.from_numpy(band_data)

                    interpolated = F.interpolate(
                        band_data.unsqueeze(0), size=MAX_TEMPORAL_IMAGE_SIZE, mode="bilinear", align_corners=False
                    ).squeeze(0)
                    images_over_time.append(interpolated)
                else:
                    output[band] = band_data

            if self.spatial_interpolate_and_stack_temporally:
                images = torch.stack(images_over_time, dim=0).numpy()
                output["image"] = images

            labels = patch_data["labels"]["labels"][:].astype(int)
            parcels = patch_data["parcels"]["parcels"][:].astype(int)

        output["mask"] = labels

        image_shape = output["image"].shape[-2:]
        mask_shape = output["mask"].shape

        if image_shape != mask_shape:
            diff_h = mask_shape[0] - image_shape[0]
            diff_w = mask_shape[1] - image_shape[1]

            output["image"] = np.pad(output["image"],
                                [(0, 0), (0, 0),
                                    (diff_h // 2, diff_h - diff_h // 2),
                                    (diff_w // 2, diff_w - diff_w // 2)],
                                mode="constant", constant_values=0)

        linear_encoder = {val: i + 1 for i, val in enumerate(sorted(SELECTED_CLASSES))}
        linear_encoder[0] = 0

        output["image"] = output["image"].transpose(0, 2, 3, 1)
        output["mask"] = self.map_mask_to_discrete_classes(output["mask"], linear_encoder)

        if self.transform:
            output = self.transform(**output)

        output["parcels"] = parcels

        return output

    def plot(self, sample, suptitle=None):
        rgb_bands = ["B04", "B03", "B02"]

        if not all(band in sample for band in rgb_bands):
            warnings.warn("No RGB image.")  # noqa: B028
            return None

        rgb_images = []
        for t in range(sample["B04"].shape[0]):
            rgb_image = torch.stack([sample[band][t] for band in rgb_bands])

            # Normalization
            rgb_min = rgb_image.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
            rgb_max = rgb_image.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
            denom = rgb_max - rgb_min
            denom[denom == 0] = 1
            rgb_image = (rgb_image - rgb_min) / denom

            rgb_image = rgb_image.permute(1, 2, 0).numpy()
            rgb_images.append(np.clip(rgb_image, 0, 1))

        dates = torch.arange(sample["B04"].shape[0])

        return self._plot_sample(rgb_images, dates, sample.get("mask"), suptitle=suptitle)  # __getitem__ stores labels under "mask"

    def _plot_sample(self, images, dates, labels=None, suptitle=None):
        num_images = len(images)
        cols = 5
        rows = (num_images + cols - 1) // cols

        # squeeze=False keeps ax 2-D even when rows == 1, so ax[r, c] indexing always works
        fig, ax = plt.subplots(rows, cols, figsize=(20, 4 * rows), squeeze=False)

        for i, image in enumerate(images):
            ax[i // cols, i % cols].imshow(image)
            ax[i // cols, i % cols].set_title(f"T{i+1} - Day {dates[i].item()}")
            ax[i // cols, i % cols].axis("off")

        if labels is not None:
            if rows * cols > num_images:
                target_ax = ax[(num_images) // cols, (num_images) % cols]
            else:
                fig.add_subplot(rows + 1, 1, 1)
                target_ax = fig.gca()

            target_ax.imshow(labels.numpy(), cmap="tab20")
            target_ax.set_title("Labels")
            target_ax.axis("off")

        for k in range(num_images, rows * cols):
            ax[k // cols, k % cols].axis("off")

        if suptitle:
            plt.suptitle(suptitle)

        plt.tight_layout()
        plt.show()

    def map_mask_to_discrete_classes(self, mask, encoder):
        map_func = np.vectorize(lambda x: encoder.get(x, 0))
        return map_func(mask)
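`__getitem__` builds `linear_encoder` from the sorted `SELECTED_CLASSES` (defined elsewhere in the module) so that sparse crop codes become contiguous indices, with 0 reserved for background and for any code not in the list. `map_mask_to_discrete_classes` applies it through `np.vectorize`, which calls a Python function once per pixel. The sketch below reproduces that mapping and shows an equivalent array-based lookup with `np.searchsorted`; the crop codes are hypothetical placeholders, not the real `SELECTED_CLASSES`.

```python
import numpy as np


def build_linear_encoder(selected_classes):
    # Same construction as in __getitem__: sorted codes -> 1..N, background 0 -> 0.
    encoder = {val: i + 1 for i, val in enumerate(sorted(selected_classes))}
    encoder[0] = 0
    return encoder


def remap_vectorized(mask, encoder):
    # Per-pixel dict lookup; unknown codes fall back to 0 (mirrors the listing).
    return np.vectorize(lambda x: encoder.get(x, 0))(mask)


def remap_searchsorted(mask, encoder):
    # Array-based alternative: locate each code in the sorted key array.
    keys = np.array(sorted(encoder), dtype=np.int64)
    values = np.array([encoder[k] for k in keys], dtype=np.int64)
    idx = np.clip(np.searchsorted(keys, mask), 0, len(keys) - 1)
    found = keys[idx] == mask
    return np.where(found, values[idx], 0)


# Hypothetical crop codes; 999 stands for a code that is not selected.
codes = [110, 120, 140]
encoder = build_linear_encoder(codes)
mask = np.array([[0, 110, 999], [140, 120, 110]])
print(remap_searchsorted(mask, encoder).tolist())  # [[0, 1, 0], [3, 2, 1]]
```

Both functions give the same result; the array version simply avoids the per-pixel Python call.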
__init__(data_root, bands=None, scenario='random', split='train', transform=None, truncate_image=4, pad_image=4, spatial_interpolate_and_stack_temporally=True, seed=42)

PyTorch dataset class for loading samples from the Sen4AgriNet dataset, with several scenarios for splitting the data.

Parameters:
  • data_root (str) –

    Root directory of the dataset.

  • bands (list of str, default: None ) –

    List of band names to load. Defaults to all available bands.

  • scenario (str, default: 'random' ) –

    Defines the splitting scenario to use. Options are:
      • 'random': random split of the data.
      • 'spatial': split by geographical region (Catalonia and France).
      • 'spatio-temporal': split by region and year (France 2019 and Catalonia 2020).

  • split (str, default: 'train' ) –

    Specifies the dataset split. Options are 'train', 'val', or 'test'.

  • transform (Compose, default: None ) –

    Albumentations transformations to apply to the data.

  • truncate_image (int, default: 4 ) –

    Number of most recent timesteps to keep along the time dimension of the image. If None, no truncation is applied. Default is 4.

  • pad_image (int, default: 4 ) –

    Target length of the time dimension when padding the image. If None, no padding is applied. Default is 4.

  • spatial_interpolate_and_stack_temporally (bool, default: True ) –

    Whether to spatially interpolate all bands to a common size and stack them into a single image array.

  • seed (int, default: 42 ) –

    Random seed used for data splitting.
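For each band, `__getitem__` sorts the frames by acquisition time, keeps only the last `truncate_image` of them, and then hands the result to `pad_numpy`, whose body is not shown on this page. The sketch below assumes that padding appends zero frames at the end of the time axis until `pad_image` timesteps are present; treat that convention, and the helper name `truncate_and_pad`, as assumptions rather than the library's API.

```python
import numpy as np


def truncate_and_pad(band_data, time_vector, truncate=4, pad=4):
    """Sort frames by time, keep the latest `truncate`, then zero-pad to `pad` frames.

    band_data has shape (T, H, W) and time_vector has shape (T,).
    Zero-padding at the end of the time axis is an assumed convention.
    """
    order = np.argsort(time_vector)
    band_data = band_data[order].astype(np.float32)
    if truncate:
        band_data = band_data[-truncate:]  # most recent frames only
    if pad and band_data.shape[0] < pad:
        missing = pad - band_data.shape[0]
        band_data = np.pad(band_data, [(0, missing), (0, 0), (0, 0)])
    return band_data


# Three frames recorded out of order; each frame is filled with its timestamp.
times = np.array([30, 10, 20])
frames = np.stack([np.full((2, 2), t) for t in times])
out = truncate_and_pad(frames, times, truncate=2, pad=4)
print(out.shape, out[:, 0, 0].tolist())  # (4, 2, 2) [20.0, 30.0, 0.0, 0.0]
```

With the defaults (`truncate_image=4`, `pad_image=4`), every band therefore ends up with exactly four timesteps.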

Source code in terratorch/datasets/sen4agrinet.py
def __init__(
    self,
    data_root: str,
    bands: list[str] | None = None,
    scenario: str = "random",
    split: str = "train",
    transform: A.Compose = None,
    truncate_image: int | None = 4,
    pad_image: int | None = 4,
    spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT001, FBT002
    seed: int = 42,
):
    """
    PyTorch Dataset class to load samples from the [Sen4AgriNet](https://github.com/Orion-AI-Lab/S4A) dataset, supporting
    multiple scenarios for splitting the data.

    Args:
        data_root (str): Root directory of the dataset.
        bands (list of str, optional): List of band names to load. Defaults to all available bands.
        scenario (str): Defines the splitting scenario to use. Options are:
            - 'random': Random split of the data.
            - 'spatial': Split by geographical regions (Catalonia and France).
            - 'spatio-temporal': Split by region and year (France 2019 and Catalonia 2020).
        split (str): Specifies the dataset split. Options are 'train', 'val', or 'test'.
        transform (albumentations.Compose, optional): Albumentations transformations to apply to the data.
        truncate_image (int, optional): Number of most recent timesteps to keep along the time dimension
            of the image. If None, no truncation is applied. Default is 4.
        pad_image (int, optional): Target length of the time dimension when padding the image.
            If None, no padding is applied. Default is 4.
        spatial_interpolate_and_stack_temporally (bool): Whether to spatially interpolate all bands to a
            common size and stack them into a single image array.
        seed (int): Random seed used for data splitting.
    """
    self.data_root = Path(data_root) / "data"
    self.transform = transform if transform else lambda **batch: to_tensor(batch)
    self.scenario = scenario
    self.seed = seed
    self.truncate_image = truncate_image
    self.pad_image = pad_image
    self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally

    if bands is None:
        bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B10", "B11", "B12", "B8A"]
    self.bands = bands

    self.image_files = list(self.data_root.glob("**/*.nc"))

    self.train_files, self.val_files, self.test_files = self.split_data()

    if split == "train":
        self.image_files = self.train_files
    elif split == "val":
        self.image_files = self.val_files
    elif split == "test":
        self.image_files = self.test_files
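`split_data` (whose tail is shown at the top of this section) chooses the train, validation and test file lists according to `scenario`. For `'spatio-temporal'`, the visible code trains on France 2019 patches, holds out the first 20% of them for validation, and tests on Catalonia 2020. A minimal sketch of that rule, using hypothetical tile IDs and file names in place of the module's `FR_TILES` and `CAT_TILES`:

```python
from pathlib import Path

# Hypothetical tile identifiers; the real FR_TILES / CAT_TILES are defined in TerraTorch.
FR_TILES = ["31TCJ"]
CAT_TILES = ["31TBF"]


def spatio_temporal_split(files, val_fraction=0.2):
    france = [f for f in files if any(tile in f.stem for tile in FR_TILES)]
    catalonia = [f for f in files if any(tile in f.stem for tile in CAT_TILES)]
    france_2019 = [f for f in france if "2019" in f.stem]
    catalonia_2020 = [f for f in catalonia if "2020" in f.stem]
    n_val = int(val_fraction * len(france_2019))
    # The first n_val France 2019 files go to validation, the rest to training.
    return france_2019[n_val:], france_2019[:n_val], catalonia_2020


files = [Path(f"2019_31TCJ_patch_{i}.nc") for i in range(10)]
files += [Path("2020_31TBF_patch_0.nc"), Path("2019_31TBF_patch_0.nc")]
train, val, test = spatio_temporal_split(files)
print(len(train), len(val), [f.name for f in test])  # 8 2 ['2020_31TBF_patch_0.nc']
```

Note that files matching neither rule (here, Catalonia 2019) are not used in any split under this scenario.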

terratorch.datasets.sen4map

Sen4MapDatasetMonthlyComposites

Bases: Dataset

Sen4Map Dataset for Monthly Composites.

Dataset intended for land-cover and crop classification tasks based on monthly composites derived from multi-temporal satellite data stored in HDF5 files.

Dataset Format:

  • HDF5 files containing multi-temporal acquisitions with spectral bands (e.g., B2, B3, …, B12)
  • Composite images computed as the median across available acquisitions for each month.
  • Classification labels provided via HDF5 attributes (e.g., 'lc1') with mappings defined for:
    • Land-cover: using land_cover_classification_map
    • Crops: using crop_classification_map

Dataset Features:

  • Supports two classification tasks: "land-cover" (default) and "crops".
  • Pre-processing options include center cropping, reverse tiling, and resizing.
  • Option to save the HDF5 keys for later filtering.
  • Input channel selection via a mapping between available bands and input bands.
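Input channel selection resolves each entry of `input_bands` to its position in `dataset_bands`, silently skipping bands that are not present, and then indexes the channel axis of the image. A short sketch with plain band-name strings (the class also accepts `HLSBands` members or integers); `select_input_channels` is a hypothetical helper, not part of TerraTorch:

```python
import torch


def select_input_channels(image, dataset_bands, input_bands):
    """Keep only the channels of `image` (C, T, H, W) listed in `input_bands`."""
    if input_bands and not dataset_bands:
        raise ValueError("input_bands was provided without specifying the dataset_bands")
    if not (input_bands and dataset_bands):
        return image  # no selection requested
    # Bands missing from dataset_bands are skipped rather than raising.
    channels = [dataset_bands.index(band) for band in input_bands if band in dataset_bands]
    if not channels:
        return image  # __getitem__ also skips selection when the list is empty
    return image[channels, ...]


# Band order matches the stacking order used in get_data.
dataset_bands = ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B11", "B12"]
image = torch.arange(10.0).view(10, 1, 1, 1)  # channel c holds the value c
subset = select_input_channels(image, dataset_bands, ["B4", "B3", "B2", "B13"])
print(subset.flatten().tolist())  # [2.0, 1.0, 0.0]; "B13" is dropped
```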
Source code in terratorch/datasets/sen4map.py
class Sen4MapDatasetMonthlyComposites(Dataset):
    """[Sen4Map](https://gitlab.jsc.fz-juelich.de/sdlrs/sen4map-benchmark-dataset) Dataset for Monthly Composites.

    Dataset intended for land-cover and crop classification tasks based on monthly composites
    derived from multi-temporal satellite data stored in HDF5 files.

    Dataset Format:

    * HDF5 files containing multi-temporal acquisitions with spectral bands (e.g., B2, B3, …, B12)
    * Composite images computed as the median across available acquisitions for each month.
    * Classification labels provided via HDF5 attributes (e.g., 'lc1') with mappings defined for:
        - Land-cover: using `land_cover_classification_map`
        - Crops: using `crop_classification_map`

    Dataset Features:

    * Supports two classification tasks: "land-cover" (default) and "crops".
    * Pre-processing options include center cropping, reverse tiling, and resizing.
    * Option to save the HDF5 keys for later filtering.
    * Input channel selection via a mapping between available bands and input bands.


    """
    land_cover_classification_map={'A10':0, 'A11':0, 'A12':0, 'A13':0,
    'A20':0, 'A21':0, 'A30':0,
    'A22':1, 'F10':1, 'F20':1,
    'F30':1, 'F40':1,
    'E10':2, 'E20':2, 'E30':2, 'B50':2, 'B51':2, 'B52':2,
    'B53':2, 'B54':2, 'B55':2,
    'B10':3, 'B11':3, 'B12':3, 'B13':3, 'B14':3, 'B15':3,
    'B16':3, 'B17':3, 'B18':3, 'B19':3, 'B20':3,
    'B21':3, 'B22':3, 'B23':3, 'B30':3, 'B31':3, 'B32':3,
    'B33':3, 'B34':3, 'B35':3, 'B36':3, 'B37':3,
    'B40':3, 'B41':3, 'B42':3, 'B43':3, 'B44':3, 'B45':3,
    'B70':3, 'B71':3, 'B72':3, 'B73':3, 'B74':3, 'B75':3,
    'B76':3, 'B77':3, 'B80':3, 'B81':3, 'B82':3, 'B83':3,
    'B84':3,
    'BX1':3, 'BX2':3,
    'C10':4, 'C20':5, 'C21':5, 'C22':5,
    'C23':5, 'C30':5, 'C31':5, 'C32':5,
    'C33':5,
    'CXX1':5, 'CXX2':5, 'CXX3':5, 'CXX4':5, 'CXX5':5,
    'CXX6':5, 'CXX7':5, 'CXX8':5, 'CXX9':5,
    'CXXA':5, 'CXXB':5, 'CXXC':5, 'CXXD':5, 'CXXE':5,
    'D10':6, 'D20':6,
    'G10':7, 'G11':7, 'G12':7, 'G20':7, 'G21':7, 'G22':7, 'G30':7,
    'G40':7,
    'G50':7,
    'H10':8, 'H11':8, 'H12':8, 'H20':8, 'H21':8,
    'H22':8, 'H23':8, '': 9}
    #  This dictionary maps the LUCAS classes to crop classes.
    crop_classification_map = {
        "B11":0, "B12":0, "B13":0, "B14":0, "B15":0, "B16":0, "B17":0, "B18":0, "B19":0,  # Cereals
        "B21":1, "B22":1, "B23":1,  # Root Crops
        "B31":2, "B32":2, "B33":2, "B34":2, "B35":2, "B36":2, "B37":2,  # Nonpermanent Industrial Crops
        "B41":3, "B42":3, "B43":3, "B44":3, "B45":3,  # Dry Pulses, Vegetables and Flowers
        "B51":4, "B52":4, "B53":4, "B54":4,  # Fodder Crops
        "F10":5, "F20":5, "F30":5, "F40":5,  # Bareland
        "B71":6, "B72":6, "B73":6, "B74":6, "B75":6, "B76":6, "B77":6, 
        "B81":6, "B82":6, "B83":6, "B84":6, "C10":6, "C21":6, "C22":6, "C23":6, "C31":6, "C32":6, "C33":6, "D10":6, "D20":6,  # Woodland and Shrubland
        "B55":7, "E10":7, "E20":7, "E30":7,  # Grassland
    }

    def __init__(
            self,
            h5py_file_object:h5py.File,
            h5data_keys = None,
            crop_size:None|int = None,
            dataset_bands:list[HLSBands|int]|None = None,
            input_bands:list[HLSBands|int]|None = None,
            resize = False,
            resize_to = [224, 224],
            resize_interpolation = InterpolationMode.BILINEAR,
            resize_antialiasing = True,
            reverse_tile = False,
            reverse_tile_size = 3,
            save_keys_path = None,
            classification_map = "land-cover"
            ):
        """Initialize a new instance of Sen4MapDatasetMonthlyComposites.

        This dataset loads data from an HDF5 file object containing multi-temporal satellite data and computes
        monthly composite images by aggregating acquisitions (via median).

        Args:
            h5py_file_object: An open h5py.File object containing the dataset.
            h5data_keys: Optional list of keys to select a subset of data samples from the HDF5 file.
                If None, all keys are used.
            crop_size: Optional integer specifying the square crop size for the output image.
            dataset_bands: Optional list of bands available in the dataset.
            input_bands: Optional list of bands to be used as input channels.
                Must be provided along with `dataset_bands`.
            resize: Boolean flag indicating whether the image should be resized. Default is False.
            resize_to: Target dimensions [height, width] for resizing. Default is [224, 224].
            resize_interpolation: Interpolation mode used for resizing. Default is InterpolationMode.BILINEAR.
            resize_antialiasing: Boolean flag to apply antialiasing during resizing. Default is True.
            reverse_tile: Boolean flag indicating whether to apply reverse tiling to the image. Default is False.
            reverse_tile_size: Kernel size for the reverse tiling operation. Must be an odd number >= 3. Default is 3.
            save_keys_path: Optional file path to save the list of dataset keys.
            classification_map: String specifying the classification mapping to use ("land-cover" or "crops").
                Default is "land-cover".

        Raises:
            ValueError: If `input_bands` is provided without specifying `dataset_bands`.
            ValueError: If an invalid `classification_map` is provided.
        """
        self.h5data = h5py_file_object
        if h5data_keys is None:
            if classification_map == "crops": print(f"Crop classification task chosen but no keys supplied. Will fail unless dataset hdf5 files have been filtered. Either filter dataset files or create a filtered set of keys.")
            self.h5data_keys = list(self.h5data.keys())
            if save_keys_path is not None:
                with open(save_keys_path, "wb") as file:
                    pickle.dump(self.h5data_keys, file)
        else:
            self.h5data_keys = h5data_keys
        self.crop_size = crop_size
        if input_bands and not dataset_bands:
            raise ValueError(f"input_bands was provided without specifying the dataset_bands")
        # self.dataset_bands = dataset_bands
        # self.input_bands = input_bands
        if input_bands and dataset_bands:
            self.input_channels = [dataset_bands.index(band_ind) for band_ind in input_bands if band_ind in dataset_bands]
        else: self.input_channels = None

        classification_maps = {"land-cover": Sen4MapDatasetMonthlyComposites.land_cover_classification_map,
                               "crops": Sen4MapDatasetMonthlyComposites.crop_classification_map}
        if classification_map not in classification_maps.keys():
            raise ValueError(f"Provided classification_map of: {classification_map}, is not from the list of valid ones: {classification_maps}")
        self.classification_map = classification_maps[classification_map]

        self.resize = resize
        self.resize_to = resize_to
        self.resize_interpolation = resize_interpolation
        self.resize_antialiasing = resize_antialiasing

        self.reverse_tile = reverse_tile
        self.reverse_tile_size = reverse_tile_size

    def __getitem__(self, index):
        # we can call dataset with an index, eg. dataset[0]
        im = self.h5data[self.h5data_keys[index]]
        Image, Label = self.get_data(im)
        Image = self.min_max_normalize(Image, [67.0, 122.0, 93.27, 158.5, 160.77, 174.27, 162.27, 149.0, 84.5, 66.27 ],
                                    [2089.0, 2598.45, 3214.5, 3620.45, 4033.61, 4613.0, 4825.45, 4945.72, 5140.84, 4414.45])

        Image = Image.clip(0,1)
        Label = torch.LongTensor(Label)
        if self.input_channels:
            Image = Image[self.input_channels, ...]

        return {"image":Image, "label":Label}

    def __len__(self):
        return len(self.h5data_keys)

    def get_data(self, im):
        mask = im['SCL'] < 9

        B2= np.where(mask==1, im['B2'], 0)
        B3= np.where(mask==1, im['B3'], 0)
        B4= np.where(mask==1, im['B4'], 0)
        B5= np.where(mask==1, im['B5'], 0)
        B6= np.where(mask==1, im['B6'], 0)
        B7= np.where(mask==1, im['B7'], 0)
        B8= np.where(mask==1, im['B8'], 0)
        B8A= np.where(mask==1, im['B8A'], 0)
        B11= np.where(mask==1, im['B11'], 0)
        B12= np.where(mask==1, im['B12'], 0)
        Image = np.stack((B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12), axis=0, dtype="float32")
        Image = np.moveaxis(Image, [0],[1])
        Image = torch.from_numpy(Image)

        # Composites:
        n1= [i for i, s in enumerate(im.attrs['Image_ID'].tolist()) if '201801' in s]
        n2= [i for i, s in enumerate(im.attrs['Image_ID'].tolist()) if '201802' in s]
        n3= [i for i, s in enumerate(im.attrs['Image_ID'].tolist()) if '201803' in s]
        n4= [i for i, s in enumerate(im.attrs['Image_ID'].tolist()) if '201804' in s]
        n5= [i for i, s in enumerate(im.attrs['Image_ID'].tolist()) if '201805' in s]
        n6= [i for i, s in enumerate(im.attrs['Image_ID'].tolist()) if '201806' in s]
        n7= [i for i, s in enumerate(im.attrs['Image_ID'].tolist()) if '201807' in s]
        n8= [i for i, s in enumerate(im.attrs['Image_ID'].tolist()) if '201808' in s]
        n9= [i for i, s in enumerate(im.attrs['Image_ID'].tolist()) if '201809' in s]
        n10= [i for i, s in enumerate(im.attrs['Image_ID'].tolist()) if '201810' in s]
        n11= [i for i, s in enumerate(im.attrs['Image_ID'].tolist()) if '201811' in s]
        n12= [i for i, s in enumerate(im.attrs['Image_ID'].tolist()) if '201812' in s]


        Jan= n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 if not n1 else n1
        Feb= n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 if not n2 else n2
        Mar= n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 if not n3 else n3
        Apr= n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 if not n4 else n4
        May= n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 if not n5 else n5
        Jun= n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 if not n6 else n6
        Jul= n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 if not n7 else n7
        Aug= n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 if not n8 else n8
        Sep= n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 if not n9 else n9
        Oct= n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 if not n10 else n10
        Nov= n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 if not n11 else n11
        Dec= n1 + n2 + n3 + n4 + n5 + n6 + n7 + n8 + n9 + n10 + n11 + n12 if not n12 else n12

        month_indices = [Jan, Feb, Mar, Apr, May, Jun, Jul, Aug, Sep, Oct, Nov, Dec]

        month_medians = [torch.stack([Image[month_indices[i][j]] for j in range(len(month_indices[i]))]).median(dim=0).values for i in range(12)]


        Image = torch.stack(month_medians, dim=0)
        Image = torch.moveaxis(Image, 0, 1)

        if self.crop_size: Image = self.crop_center(Image, self.crop_size, self.crop_size)
        if self.reverse_tile:
            Image = self.reverse_tiling_pytorch(Image, kernel_size=self.reverse_tile_size)
        if self.resize:
            Image = resize(Image, size=self.resize_to, interpolation=self.resize_interpolation, antialias=self.resize_antialiasing)

        Label = im.attrs['lc1']
        Label = self.classification_map[Label]
        Label = np.array(Label)
        Label = Label.astype('float32')

        return Image, Label

    def crop_center(self, img_b:torch.Tensor, cropx, cropy) -> torch.Tensor:
        c, t, y, x = img_b.shape
        startx = x//2-(cropx//2)
        starty = y//2-(cropy//2)    
        return img_b[0:c, 0:t, starty:starty+cropy, startx:startx+cropx]


    def reverse_tiling_pytorch(self, img_tensor: torch.Tensor, kernel_size: int=3):
        """
        Upscales an image where every pixel is expanded into `kernel_size`*`kernel_size` pixels.
        Used to test whether the benefit of resizing images to the pre-trained size comes from the bilinearly interpolated pixels,
        or if the same would be realized with no interpolated pixels.
        """
        assert kernel_size % 2 == 1
        assert kernel_size >= 3
        padding = (kernel_size - 1) // 2
        # img_tensor shape: (batch_size, channels, H, W)
        batch_size, channels, H, W = img_tensor.shape
        # Unfold: Extract 3x3 patches with padding of 1 to cover borders
        img_tensor = F.pad(img_tensor, pad=(padding,padding,padding,padding), mode="replicate")
        patches = F.unfold(img_tensor, kernel_size=kernel_size, padding=0)  # Shape: (batch_size, channels*9, H*W)
        # Reshape to organize the 9 values from each 3x3 neighborhood
        patches = patches.view(batch_size, channels, kernel_size*kernel_size, H, W)  # Shape: (batch_size, channels, 9, H, W)
        # Rearrange the patches into (batch_size, channels, 3, 3, H, W)
        patches = patches.view(batch_size, channels, kernel_size, kernel_size, H, W)
        # Permute to have the spatial dimensions first and unfold them
        patches = patches.permute(0, 1, 4, 2, 5, 3)  # Shape: (batch_size, channels, H, 3, W, 3)
        # Reshape to get the final expanded image of shape (batch_size, channels, H*3, W*3)
        expanded_img = patches.reshape(batch_size, channels, H * kernel_size, W * kernel_size)
        return expanded_img

    def min_max_normalize(self, tensor:torch.Tensor, q_low:list[float], q_hi:list[float]) -> torch.Tensor:
        dtype = tensor.dtype
        q_low = torch.as_tensor(q_low, dtype=dtype, device=tensor.device)
        q_hi = torch.as_tensor(q_hi, dtype=dtype, device=tensor.device)
        x = torch.tensor(-12.0)
        y = torch.exp(x)
        tensor.sub_(q_low[:, None, None, None]).div_((q_hi[:, None, None, None].sub_(q_low[:, None, None, None])).add(y))
        return tensor
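`__getitem__` rescales each of the ten channels with fixed per-band low/high values through `min_max_normalize`, which broadcasts the `(C,)` bounds over the `(C, T, H, W)` composite, adds `exp(-12)` to the denominator to avoid division by zero, and then clips the result to [0, 1]. The listed method works in place on its input; the sketch below performs the same arithmetic out of place, with illustrative bounds rather than the dataset's values.

```python
import math

import torch


def min_max_normalize(tensor, q_low, q_hi, eps=math.exp(-12.0)):
    """Scale each channel of a (C, T, H, W) tensor with its own low/high bounds.

    Unlike the listed method, this version does not modify `tensor` in place.
    """
    # Reshape the (C,) bounds to (C, 1, 1, 1) so they broadcast over T, H and W.
    low = torch.as_tensor(q_low, dtype=tensor.dtype, device=tensor.device)[:, None, None, None]
    high = torch.as_tensor(q_hi, dtype=tensor.dtype, device=tensor.device)[:, None, None, None]
    return (tensor - low) / (high - low + eps)


# Two channels, one timestep, a 1 x 3 image; bounds are illustrative only.
x = torch.tensor([[0.0, 50.0, 200.0], [10.0, 20.0, 30.0]]).view(2, 1, 1, 3)
out = min_max_normalize(x, [0.0, 10.0], [100.0, 30.0]).clip(0, 1)  # clip as in __getitem__
print(out.flatten().tolist())
```

Values above the high bound (200 in channel 0) saturate at 1 after clipping.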
__init__(h5py_file_object, h5data_keys=None, crop_size=None, dataset_bands=None, input_bands=None, resize=False, resize_to=[224, 224], resize_interpolation=InterpolationMode.BILINEAR, resize_antialiasing=True, reverse_tile=False, reverse_tile_size=3, save_keys_path=None, classification_map='land-cover')

Initialize a new instance of Sen4MapDatasetMonthlyComposites.

This dataset loads data from an HDF5 file object containing multi-temporal satellite data and computes monthly composite images by aggregating acquisitions (via median).

Parameters:
  • h5py_file_object (File) –

    An open h5py.File object containing the dataset.

  • h5data_keys

    Optional list of keys to select a subset of data samples from the HDF5 file. If None, all keys are used.

  • crop_size (None | int, default: None ) –

    Optional integer specifying the square crop size for the output image.

  • dataset_bands (list[HLSBands | int] | None, default: None ) –

    Optional list of bands available in the dataset.

  • input_bands (list[HLSBands | int] | None, default: None ) –

    Optional list of bands to be used as input channels. Must be provided along with dataset_bands.

  • resize

    Boolean flag indicating whether the image should be resized. Default is False.

  • resize_to

    Target dimensions [height, width] for resizing. Default is [224, 224].

  • resize_interpolation

    Interpolation mode used for resizing. Default is InterpolationMode.BILINEAR.

  • resize_antialiasing

    Boolean flag to apply antialiasing during resizing. Default is True.

  • reverse_tile

    Boolean flag indicating whether to apply reverse tiling to the image. Default is False.

  • reverse_tile_size

    Kernel size for the reverse tiling operation. Must be an odd number >= 3. Default is 3.

  • save_keys_path

    Optional file path to save the list of dataset keys.

  • classification_map

    String specifying the classification mapping to use ("land-cover" or "crops"). Default is "land-cover".

Raises:
  • ValueError

    If input_bands is provided without specifying dataset_bands.

  • ValueError

    If an invalid classification_map is provided.
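When `crop_size` is set, `get_data` calls `crop_center` on the `(C, T, H, W)` composite, keeping a `crop_size` by `crop_size` window around the spatial centre before any reverse tiling or resizing. A sketch of the same index arithmetic as a free function:

```python
import torch


def crop_center(img, crop_x, crop_y):
    # img has shape (C, T, H, W); only the last two (spatial) axes are cropped.
    _, _, h, w = img.shape
    start_x = w // 2 - crop_x // 2
    start_y = h // 2 - crop_y // 2
    return img[:, :, start_y:start_y + crop_y, start_x:start_x + crop_x]


img = torch.arange(36.0).view(1, 1, 6, 6)  # pixel (r, c) holds r * 6 + c
out = crop_center(img, 2, 2)
print(out.flatten().tolist())  # [14.0, 15.0, 20.0, 21.0]
```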

reverse_tiling_pytorch(img_tensor, kernel_size=3)

Upscales an image by expanding every pixel into a kernel_size by kernel_size block of pixels taken from its neighborhood. Used to test whether the benefit of resizing images to the pre-trained size comes from the bilinearly interpolated pixels, or whether the same benefit is obtained without any interpolated pixels.
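The operation can be reproduced outside the class to check what it produces. In the sketch below, `reverse_tile` is a standalone copy of the listed steps (replicate padding, `F.unfold`, then a permute and reshape). Each input pixel becomes a block holding its own neighborhood, so the centre of every block equals the original pixel and only original pixel values appear in the output. In `get_data` the function receives the `(C, T, H, W)` composite, so the channel axis plays the role of the batch axis.

```python
import torch
import torch.nn.functional as F


def reverse_tile(img, kernel_size=3):
    """Standalone copy of reverse_tiling_pytorch for a (B, C, H, W) tensor."""
    assert kernel_size % 2 == 1 and kernel_size >= 3
    pad = (kernel_size - 1) // 2
    b, c, h, w = img.shape
    padded = F.pad(img, pad=(pad, pad, pad, pad), mode="replicate")
    patches = F.unfold(padded, kernel_size=kernel_size)  # (B, C*k*k, H*W)
    patches = patches.view(b, c, kernel_size, kernel_size, h, w)
    patches = patches.permute(0, 1, 4, 2, 5, 3)  # (B, C, H, k, W, k)
    return patches.reshape(b, c, h * kernel_size, w * kernel_size)


img = torch.arange(12.0).view(1, 1, 3, 4)
out = reverse_tile(img, kernel_size=3)
print(out.shape)  # torch.Size([1, 1, 9, 12])
# The centre of every 3x3 output block is the original pixel.
print(torch.equal(out[..., 1::3, 1::3], img))  # True
```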


Datamodules

terratorch.datamodules.biomassters

BioMasstersNonGeoDataModule

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for BioMassters datamodule.

Source code in terratorch/datamodules/biomassters.py
class BioMasstersNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for BioMassters datamodule."""

    default_metadata_filename = "The_BioMassters_-_features_metadata.csv.csv"

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: dict[str, Sequence[str]] | Sequence[str] = BioMasstersNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        drop_last: bool = True,
        sensors: Sequence[str] = ["S1", "S2"],
        as_time_series: bool = False,
        metadata_filename: str = default_metadata_filename,
        max_cloud_percentage: float | None = None,
        max_red_mean: float | None = None,
        include_corrupt: bool = True,
        subset: float = 1,
        seed: int = 42,
        use_four_frames: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the DataModule for the non-geospatial BioMassters dataset.

        Args:
            data_root (str): Root directory containing the dataset.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (dict[str, Sequence[str]] | Sequence[str], optional): Band configuration; either a dict mapping sensors to bands or a list applied to the first sensor.
                Defaults to BioMasstersNonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            aug (AugmentationSequential, optional): Augmentation or normalization to apply. Defaults to normalization if not provided.
            drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
            sensors (Sequence[str], optional): List of sensors to use (e.g., ["S1", "S2"]). Defaults to ["S1", "S2"].
            as_time_series (bool, optional): Whether to treat data as a time series. Defaults to False.
            metadata_filename (str, optional): Metadata filename. Defaults to "The_BioMassters_-_features_metadata.csv.csv".
            max_cloud_percentage (float | None, optional): Maximum allowed cloud percentage. Defaults to None.
            max_red_mean (float | None, optional): Maximum allowed red band mean. Defaults to None.
            include_corrupt (bool, optional): Whether to include corrupt data. Defaults to True.
            subset (float, optional): Fraction of the dataset to use. Defaults to 1.
            seed (int, optional): Random seed for reproducibility. Defaults to 42.
            use_four_frames (bool, optional): Whether to use a four-frame configuration. Defaults to False.
            **kwargs: Additional keyword arguments.

        Returns:
            None.
        """
        super().__init__(BioMasstersNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.sensors = sensors
        if isinstance(bands, dict):
            self.bands = bands
        else:
            sens = sensors[0]
            self.bands = {sens: bands}

        self.means = {}
        self.stds = {}
        for sensor in self.sensors:
            self.means[sensor] = [MEANS[sensor][band] for band in self.bands[sensor]]
            self.stds[sensor] = [STDS[sensor][band] for band in self.bands[sensor]]

        self.mask_mean = MEANS["AGBM"]
        self.mask_std = STDS["AGBM"]
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        if len(sensors) == 1:
            self.aug = Normalize(self.means[sensors[0]], self.stds[sensors[0]]) if aug is None else aug
        else:
            self.aug = MultimodalNormalize(self.means, self.stds) if aug is None else aug
        self.drop_last = drop_last
        self.as_time_series = as_time_series
        self.metadata_filename = metadata_filename
        self.max_cloud_percentage = max_cloud_percentage
        self.max_red_mean = max_red_mean
        self.include_corrupt = include_corrupt
        self.subset = subset
        self.seed = seed
        self.use_four_frames = use_four_frames

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                mask_mean=self.mask_mean,
                mask_std=self.mask_std,
                sensors=self.sensors,
                as_time_series=self.as_time_series,
                metadata_filename=self.metadata_filename,
                max_cloud_percentage=self.max_cloud_percentage,
                max_red_mean=self.max_red_mean,
                include_corrupt=self.include_corrupt,
                subset=self.subset,
                seed=self.seed,
                use_four_frames=self.use_four_frames,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="test",
                root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                mask_mean=self.mask_mean,
                mask_std=self.mask_std,
                sensors=self.sensors,
                as_time_series=self.as_time_series,
                metadata_filename=self.metadata_filename,
                max_cloud_percentage=self.max_cloud_percentage,
                max_red_mean=self.max_red_mean,
                include_corrupt=self.include_corrupt,
                subset=self.subset,
                seed=self.seed,
                use_four_frames=self.use_four_frames,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                mask_mean=self.mask_mean,
                mask_std=self.mask_std,
                sensors=self.sensors,
                as_time_series=self.as_time_series,
                metadata_filename=self.metadata_filename,
                max_cloud_percentage=self.max_cloud_percentage,
                max_red_mean=self.max_red_mean,
                include_corrupt=self.include_corrupt,
                subset=self.subset,
                seed=self.seed,
                use_four_frames=self.use_four_frames,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                mask_mean=self.mask_mean,
                mask_std=self.mask_std,
                sensors=self.sensors,
                as_time_series=self.as_time_series,
                metadata_filename=self.metadata_filename,
                max_cloud_percentage=self.max_cloud_percentage,
                max_red_mean=self.max_red_mean,
                include_corrupt=self.include_corrupt,
                subset=self.subset,
                seed=self.seed,
                use_four_frames=self.use_four_frames,
            )

    def _dataloader_factory(self, split: str):
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
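`_dataloader_factory` shuffles, and drops the last incomplete batch, only for the "train" split, so `drop_last=True` has no effect on validation, test, or prediction loaders. The sketch below reproduces just that keyword logic in plain Python without building a `DataLoader`. `first_set` and `dataloader_kwargs` are hypothetical helpers, and the fallback from a split-specific attribute such as `val_batch_size` to the generic `batch_size` is an assumption about how TorchGeo's `_valid_attribute` resolves settings.

```python
from types import SimpleNamespace


def first_set(obj, *names):
    """Return the first attribute among `names` that is not None.

    Assumed to match how `_valid_attribute` resolves split-specific settings.
    """
    for name in names:
        value = getattr(obj, name, None)
        if value is not None:
            return value
    raise AttributeError(f"none of {names} is set")


def dataloader_kwargs(dm, split):
    """Keyword arguments `_dataloader_factory` would derive for `split`."""
    return {
        "batch_size": first_set(dm, f"{split}_batch_size", "batch_size"),
        "shuffle": split == "train",                     # only training data is shuffled
        "drop_last": split == "train" and dm.drop_last,  # drop_last applies to training only
    }


dm = SimpleNamespace(batch_size=4, val_batch_size=8, drop_last=True)
print(dataloader_kwargs(dm, "train"))  # {'batch_size': 4, 'shuffle': True, 'drop_last': True}
print(dataloader_kwargs(dm, "val"))    # {'batch_size': 8, 'shuffle': False, 'drop_last': False}
```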
__init__(data_root, batch_size=4, num_workers=0, bands=BioMasstersNonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, aug=None, drop_last=True, sensors=['S1', 'S2'], as_time_series=False, metadata_filename=default_metadata_filename, max_cloud_percentage=None, max_red_mean=None, include_corrupt=True, subset=1, seed=42, use_four_frames=False, **kwargs)

Initializes the DataModule for the non-geospatial BioMassters dataset.

Parameters:
  • data_root (str) –

    Root directory containing the dataset.

  • batch_size (int, default: 4 ) –

    Batch size for DataLoaders. Defaults to 4.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • bands (dict[str, Sequence[str]] | Sequence[str], default: all_band_names ) –

    Band configuration; either a dict mapping sensors to bands or a list applied to the first sensor. Defaults to BioMasstersNonGeo.all_band_names.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training data.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation data.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing data.

  • predict_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for prediction data.

  • aug (AugmentationSequential, default: None ) –

    Augmentation or normalization to apply. Defaults to normalization if not provided.

  • drop_last (bool, default: True ) –

    Whether to drop the last incomplete batch. Defaults to True.

  • sensors (Sequence[str], default: ['S1', 'S2'] ) –

    List of sensors to use (e.g., ["S1", "S2"]). Defaults to ["S1", "S2"].

  • as_time_series (bool, default: False ) –

    Whether to treat data as a time series. Defaults to False.

  • metadata_filename (str, default: default_metadata_filename ) –

    Metadata filename. Defaults to "The_BioMassters_-_features_metadata.csv.csv".

  • max_cloud_percentage (float | None, default: None ) –

    Maximum allowed cloud percentage. Defaults to None.

  • max_red_mean (float | None, default: None ) –

    Maximum allowed red band mean. Defaults to None.

  • include_corrupt (bool, default: True ) –

    Whether to include corrupt data. Defaults to True.

  • subset (float, default: 1 ) –

    Fraction of the dataset to use. Defaults to 1.

  • seed (int, default: 42 ) –

    Random seed for reproducibility. Defaults to 42.

  • use_four_frames (bool, default: False ) –

    Whether to use a four-frame configuration. Defaults to False.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.

Returns:
  • None

    None.
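When `bands` is a plain list, it is attached to the first sensor only, and the mean/std lookup then runs for every entry in `sensors`. With the default `sensors=["S1", "S2"]`, a list therefore leads to a `KeyError` for "S2". In that case, pass a dict with one entry per sensor instead. The following sketch mirrors that logic in plain Python. The band names and statistics in `EXAMPLE_MEANS` are made up, and the helper names are hypothetical.

```python
# Hypothetical per-band means; the real MEANS/STDS tables live in TerraTorch.
EXAMPLE_MEANS = {
    "S1": {"VV": -12.0, "VH": -18.0},
    "S2": {"B02": 0.05, "B03": 0.08, "B04": 0.07},
}


def resolve_bands(bands, sensors):
    """Normalize `bands` to a {sensor: [band, ...]} dict as the constructor does."""
    if isinstance(bands, dict):
        return dict(bands)
    # A plain sequence is attributed to the first sensor only
    return {sensors[0]: list(bands)}


def per_sensor_stats(bands_by_sensor, sensors, table):
    """Look up one statistic per selected band, for every configured sensor."""
    return {s: [table[s][b] for b in bands_by_sensor[s]] for s in sensors}


single = resolve_bands(["B04", "B02"], sensors=["S2"])
print(per_sensor_stats(single, ["S2"], EXAMPLE_MEANS))  # {'S2': [0.07, 0.05]}

both = resolve_bands(["VV"], sensors=["S1", "S2"])
try:
    per_sensor_stats(both, ["S1", "S2"], EXAMPLE_MEANS)
except KeyError as err:
    print("missing bands for sensor", err)  # missing bands for sensor 'S2'
```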

Source code in terratorch/datamodules/biomassters.py
def __init__(
    self,
    data_root: str,
    batch_size: int = 4,
    num_workers: int = 0,
    bands: dict[str, Sequence[str]] | Sequence[str] = BioMasstersNonGeo.all_band_names,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential = None,
    drop_last: bool = True,
    sensors: Sequence[str] = ["S1", "S2"],
    as_time_series: bool = False,
    metadata_filename: str = default_metadata_filename,
    max_cloud_percentage: float | None = None,
    max_red_mean: float | None = None,
    include_corrupt: bool = True,
    subset: float = 1,
    seed: int = 42,
    use_four_frames: bool = False,
    **kwargs: Any,
) -> None:
    """
    Initializes the DataModule for the non-geospatial BioMassters dataset.

    Args:
        data_root (str): Root directory containing the dataset.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        bands (dict[str, Sequence[str]] | Sequence[str], optional): Band configuration; either a dict mapping sensors to bands or a list applied to the first sensor.
            Defaults to BioMasstersNonGeo.all_band_names.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        aug (AugmentationSequential, optional): Augmentation or normalization to apply. Defaults to normalization if not provided.
        drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
        sensors (Sequence[str], optional): List of sensors to use (e.g., ["S1", "S2"]). Defaults to ["S1", "S2"].
        as_time_series (bool, optional): Whether to treat data as a time series. Defaults to False.
        metadata_filename (str, optional): Metadata filename. Defaults to "The_BioMassters_-_features_metadata.csv.csv".
        max_cloud_percentage (float | None, optional): Maximum allowed cloud percentage. Defaults to None.
        max_red_mean (float | None, optional): Maximum allowed red band mean. Defaults to None.
        include_corrupt (bool, optional): Whether to include corrupt data. Defaults to True.
        subset (float, optional): Fraction of the dataset to use. Defaults to 1.
        seed (int, optional): Random seed for reproducibility. Defaults to 42.
        use_four_frames (bool, optional): Whether to use a four-frame configuration. Defaults to False.
        **kwargs: Additional keyword arguments.

    Returns:
        None.
    """
    super().__init__(BioMasstersNonGeo, batch_size, num_workers, **kwargs)
    self.data_root = data_root
    self.sensors = sensors
    if isinstance(bands, dict):
        self.bands = bands
    else:
        sens = sensors[0]
        self.bands = {sens: bands}

    self.means = {}
    self.stds = {}
    for sensor in self.sensors:
        self.means[sensor] = [MEANS[sensor][band] for band in self.bands[sensor]]
        self.stds[sensor] = [STDS[sensor][band] for band in self.bands[sensor]]

    self.mask_mean = MEANS["AGBM"]
    self.mask_std = STDS["AGBM"]
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    if len(sensors) == 1:
        self.aug = Normalize(self.means[sensors[0]], self.stds[sensors[0]]) if aug is None else aug
    else:
        self.aug = MultimodalNormalize(self.means, self.stds) if aug is None else aug
    self.drop_last = drop_last
    self.as_time_series = as_time_series
    self.metadata_filename = metadata_filename
    self.max_cloud_percentage = max_cloud_percentage
    self.max_red_mean = max_red_mean
    self.include_corrupt = include_corrupt
    self.subset = subset
    self.seed = seed
    self.use_four_frames = use_four_frames
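The constructor's default for `aug` depends on how many sensors are configured. Here is a minimal sketch of that decision, with the result returned (and therefore assigned) in every branch. `Normalize` and `MultimodalNormalize` below are hypothetical stand-ins that only record their statistics. They are not TerraTorch's transforms.

```python
class Normalize:
    """Hypothetical stand-in for the single-sensor normalization transform."""

    def __init__(self, means, stds):
        self.means, self.stds = means, stds


class MultimodalNormalize:
    """Hypothetical stand-in for per-sensor (multimodal) normalization."""

    def __init__(self, means, stds):
        self.means, self.stds = means, stds


def default_aug(sensors, means, stds, aug=None):
    """Pick the augmentation the data module should store in `self.aug`."""
    if aug is not None:
        return aug  # a user-supplied augmentation always wins
    if len(sensors) == 1:
        sensor = sensors[0]
        return Normalize(means[sensor], stds[sensor])
    # Several sensors: normalize each modality with its own statistics
    return MultimodalNormalize(means, stds)


means = {"S1": [-12.0], "S2": [0.05]}
stds = {"S1": [3.0], "S2": [0.02]}
print(type(default_aug(["S2"], means, stds)).__name__)        # Normalize
print(type(default_aug(["S1", "S2"], means, stds)).__name__)  # MultimodalNormalize
```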
setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either fit, validate, test, or predict.

Source code in terratorch/datamodules/biomassters.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(
            split="train",
            root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            mask_mean=self.mask_mean,
            mask_std=self.mask_std,
            sensors=self.sensors,
            as_time_series=self.as_time_series,
            metadata_filename=self.metadata_filename,
            max_cloud_percentage=self.max_cloud_percentage,
            max_red_mean=self.max_red_mean,
            include_corrupt=self.include_corrupt,
            subset=self.subset,
            seed=self.seed,
            use_four_frames=self.use_four_frames,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="test",
            root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            mask_mean=self.mask_mean,
            mask_std=self.mask_std,
            sensors=self.sensors,
            as_time_series=self.as_time_series,
            metadata_filename=self.metadata_filename,
            max_cloud_percentage=self.max_cloud_percentage,
            max_red_mean=self.max_red_mean,
            include_corrupt=self.include_corrupt,
            subset=self.subset,
            seed=self.seed,
            use_four_frames=self.use_four_frames,
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="test",
            root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            mask_mean=self.mask_mean,
            mask_std=self.mask_std,
            sensors=self.sensors,
            as_time_series=self.as_time_series,
            metadata_filename=self.metadata_filename,
            max_cloud_percentage=self.max_cloud_percentage,
            max_red_mean=self.max_red_mean,
            include_corrupt=self.include_corrupt,
            subset=self.subset,
            seed=self.seed,
            use_four_frames=self.use_four_frames,
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="test",
            root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            mask_mean=self.mask_mean,
            mask_std=self.mask_std,
            sensors=self.sensors,
            as_time_series=self.as_time_series,
            metadata_filename=self.metadata_filename,
            max_cloud_percentage=self.max_cloud_percentage,
            max_red_mean=self.max_red_mean,
            include_corrupt=self.include_corrupt,
            subset=self.subset,
            seed=self.seed,
            use_four_frames=self.use_four_frames,
        )
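`setup` builds only the datasets the requested stage needs. For validation, testing, and prediction, this BioMassters module passes `split="test"`, so validation metrics during `fit` are computed on the test split. The following plain-Python sketch shows that stage-to-split plan. `datasets_for_stage` is a hypothetical helper, and `eval_split` stands for whichever split a module uses outside training.

```python
def datasets_for_stage(stage, eval_split):
    """Map each dataset attribute setup(stage) creates to the split it loads.

    `eval_split` is the split used for validation, testing and prediction.
    """
    plan = {}
    if stage == "fit":
        plan["train_dataset"] = "train"
    if stage in ("fit", "validate"):
        plan["val_dataset"] = eval_split
    if stage == "test":
        plan["test_dataset"] = eval_split
    if stage == "predict":
        plan["predict_dataset"] = eval_split
    return plan


# BioMassters reads its "test" split for validation, testing and prediction
print(datasets_for_stage("fit", "test"))  # {'train_dataset': 'train', 'val_dataset': 'test'}
# The BurnIntensity module's setup uses "val" for the same three roles
print(datasets_for_stage("test", "val"))  # {'test_dataset': 'val'}
```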

terratorch.datamodules.burn_intensity

BurnIntensityNonGeoDataModule

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for the BurnIntensity dataset.

Source code in terratorch/datamodules/burn_intensity.py
class BurnIntensityNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for the BurnIntensity dataset."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = BurnIntensityNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        use_full_data: bool = True,
        no_data_replace: float | None = 0.0001,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the DataModule for the non-geospatial BurnIntensity dataset.

        Args:
            data_root (str): Root directory of the dataset.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (Sequence[str], optional): List of bands to use. Defaults to BurnIntensityNonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction.
            use_full_data (bool, optional): Whether to use the full dataset (True) or only samples with less than 25 percent zeros (False). Defaults to True.
            no_data_replace (float | None, optional): Value to replace missing data. Defaults to 0.0001.
            no_label_replace (int | None, optional): Value to replace missing labels. Defaults to -1.
            use_metadata (bool, optional): Whether to return metadata info (time and location). Defaults to False.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(BurnIntensityNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        means = [MEANS[b] for b in bands]
        stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = NormalizeWithTimesteps(means, stds)
        self.use_full_data = use_full_data
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                use_full_data=self.use_full_data,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                use_full_data=self.use_full_data,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                use_full_data=self.use_full_data,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                use_full_data=self.use_full_data,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
__init__(data_root, batch_size=4, num_workers=0, bands=BurnIntensityNonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, use_full_data=True, no_data_replace=0.0001, no_label_replace=-1, use_metadata=False, **kwargs)

Initializes the DataModule for the non-geospatial BurnIntensity dataset.

Parameters:
  • data_root (str) –

    Root directory of the dataset.

  • batch_size (int, default: 4 ) –

    Batch size for DataLoaders. Defaults to 4.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • bands (Sequence[str], default: all_band_names ) –

    List of bands to use. Defaults to BurnIntensityNonGeo.all_band_names.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • predict_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for prediction.

  • use_full_data (bool, default: True ) –

    Whether to use the full dataset (True) or only samples with less than 25 percent zeros (False). Defaults to True.

  • no_data_replace (float | None, default: 0.0001 ) –

    Value to replace missing data. Defaults to 0.0001.

  • no_label_replace (int | None, default: -1 ) –

    Value to replace missing labels. Defaults to -1.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info (time and location). Defaults to False.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.
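One reading of `use_full_data` is that, when it is False, only samples whose pixels are less than 25 percent zeros are kept. The sketch below assumes each sample is a flat list of pixel values and that the filter is computed on the fly. The real dataset may instead rely on a precomputed list of files. `zero_fraction` and `select_samples` are hypothetical helpers.

```python
def zero_fraction(pixels):
    """Share of pixel values that are exactly zero."""
    return sum(1 for v in pixels if v == 0) / len(pixels)


def select_samples(samples, use_full_data=True, max_zero_fraction=0.25):
    """Keep every sample, or only those with fewer than 25 percent zero pixels."""
    if use_full_data:
        return list(samples)
    return [s for s in samples if zero_fraction(s) < max_zero_fraction]


samples = [[1, 2, 3, 4], [0, 1, 2, 3], [0, 0, 1, 2]]  # zero fractions: 0, 0.25, 0.5
print(len(select_samples(samples)))                  # 3
print(select_samples(samples, use_full_data=False))  # [[1, 2, 3, 4]]
```

The comparison is strict, matching "less than 25 percent": a sample with exactly a quarter of its pixels at zero is dropped.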

Source code in terratorch/datamodules/burn_intensity.py
def __init__(
    self,
    data_root: str,
    batch_size: int = 4,
    num_workers: int = 0,
    bands: Sequence[str] = BurnIntensityNonGeo.all_band_names,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    use_full_data: bool = True,
    no_data_replace: float | None = 0.0001,
    no_label_replace: int | None = -1,
    use_metadata: bool = False,
    **kwargs: Any,
) -> None:
    """
    Initializes the DataModule for the non-geospatial BurnIntensity dataset.

    Args:
        data_root (str): Root directory of the dataset.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        bands (Sequence[str], optional): List of bands to use. Defaults to BurnIntensityNonGeo.all_band_names.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction.
        use_full_data (bool, optional): Whether to use the full dataset (True) or only samples with less than 25 percent zeros (False). Defaults to True.
        no_data_replace (float | None, optional): Value to replace missing data. Defaults to 0.0001.
        no_label_replace (int | None, optional): Value to replace missing labels. Defaults to -1.
        use_metadata (bool, optional): Whether to return metadata info (time and location). Defaults to False.
        **kwargs: Additional keyword arguments.
    """
    super().__init__(BurnIntensityNonGeo, batch_size, num_workers, **kwargs)
    self.data_root = data_root

    means = [MEANS[b] for b in bands]
    stds = [STDS[b] for b in bands]
    self.bands = bands
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.aug = NormalizeWithTimesteps(means, stds)
    self.use_full_data = use_full_data
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.use_metadata = use_metadata
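`no_data_replace` (default 0.0001) and `no_label_replace` (default -1) substitute a fixed value for missing image and label pixels. Judging from the `| None` type hints, passing None presumably disables the substitution. Here is a minimal sketch for image values. It assumes missing pixels are stored as NaN, which is an assumption because the dataset's actual no-data convention is not shown here. `replace_no_data` is a hypothetical helper.

```python
import math


def replace_no_data(values, no_data_replace=0.0001):
    """Replace NaN entries (assumed no-data marker) with `no_data_replace`.

    Passing None leaves the values untouched.
    """
    if no_data_replace is None:
        return list(values)
    return [no_data_replace if math.isnan(v) else v for v in values]


print(replace_no_data([0.3, float("nan"), 0.7]))  # [0.3, 0.0001, 0.7]
```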
setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either fit, validate, test, or predict.

Source code in terratorch/datamodules/burn_intensity.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(
            split="train",
            data_root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            use_full_data=self.use_full_data,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            use_full_data=self.use_full_data,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            use_full_data=self.use_full_data,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            use_full_data=self.use_full_data,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )

terratorch.datamodules.carbonflux

CarbonFluxNonGeoDataModule

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for the Carbon Flux dataset.
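In the constructor, the per-modality statistics handed to `MultimodalNormalize` are built so that only the "image" entry is narrowed to the selected bands. Every other modality keeps its statistics unchanged, and the "mask" statistics are also stored separately and passed to the dataset as `gpp_mean`/`gpp_std`. The following plain-Python sketch shows that dictionary comprehension. The "aux" modality, the band names, and all numbers are hypothetical.

```python
# Hypothetical statistics table: only "image" is keyed by band name.
EXAMPLE_MEANS = {
    "image": {"BLUE": 0.04, "GREEN": 0.07, "RED": 0.06},
    "aux": [280.0, 0.5],
    "mask": 3.2,
}


def select_stats(table, bands):
    """Subset the "image" statistics to `bands`; keep other modalities whole."""
    return {m: ([table[m][b] for b in bands] if m == "image" else table[m]) for m in table}


print(select_stats(EXAMPLE_MEANS, ["RED", "BLUE"]))
# {'image': [0.06, 0.04], 'aux': [280.0, 0.5], 'mask': 3.2}
```

Band order in the result follows the `bands` argument, not the table, so the statistics line up with the channel order the dataset produces.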

Source code in terratorch/datamodules/carbonflux.py
class CarbonFluxNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for the Carbon Flux dataset."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = CarbonFluxNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        no_data_replace: float | None = 0.0001,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the CarbonFluxNonGeoDataModule.

        Args:
            data_root (str): Root directory of the dataset.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (Sequence[str], optional): List of bands to use. Defaults to CarbonFluxNonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            aug (AugmentationSequential, optional): Augmentation sequence; if None, applies multimodal normalization.
            no_data_replace (float | None, optional): Value to replace missing data. Defaults to 0.0001.
            use_metadata (bool, optional): Whether to return metadata info. Defaults to False.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(CarbonFluxNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        means = {
            m: ([MEANS[m][band] for band in bands] if m == "image" else MEANS[m])
            for m in MEANS.keys()
        }
        stds = {
            m: ([STDS[m][band] for band in bands] if m == "image" else STDS[m])
            for m in STDS.keys()
        }
        self.mask_means = MEANS["mask"]
        self.mask_std = STDS["mask"]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = MultimodalNormalize(means, stds) if aug is None else aug
        self.no_data_replace = no_data_replace
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                gpp_mean=self.mask_means,
                gpp_std=self.mask_std,
                no_data_replace=self.no_data_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                gpp_mean=self.mask_means,
                gpp_std=self.mask_std,
                no_data_replace=self.no_data_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                gpp_mean=self.mask_means,
                gpp_std=self.mask_std,
                no_data_replace=self.no_data_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                gpp_mean=self.mask_means,
                gpp_std=self.mask_std,
                no_data_replace=self.no_data_replace,
                use_metadata=self.use_metadata,
            )
__init__(data_root, batch_size=4, num_workers=0, bands=CarbonFluxNonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, aug=None, no_data_replace=0.0001, use_metadata=False, **kwargs)

Initializes the CarbonFluxNonGeoDataModule.

Parameters:
  • data_root (str) –

    Root directory of the dataset.

  • batch_size (int, default: 4 ) –

    Batch size for DataLoaders. Defaults to 4.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • bands (Sequence[str], default: all_band_names ) –

    List of bands to use. Defaults to CarbonFluxNonGeo.all_band_names.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training data.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation data.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing data.

  • predict_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for prediction data.

  • aug (AugmentationSequential, default: None ) –

    Augmentation sequence; if None, applies multimodal normalization.

  • no_data_replace (float | None, default: 0.0001 ) –

    Value to replace missing data. Defaults to 0.0001.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.
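The means and stds dictionaries built in __init__ subset the per-band statistics only for the "image" modality and pass every other modality through unchanged; the "mask" entry is what later becomes gpp_mean and gpp_std. The sketch below reproduces that comprehension on small placeholder tables. The band names B1 to B3 and all numbers are hypothetical, not the real CarbonFlux statistics.

```python
from collections.abc import Mapping, Sequence
from typing import Any

# Placeholder statistics (hypothetical band names and values, for illustration only).
MEANS: dict[str, Any] = {
    "image": {"B1": 0.10, "B2": 0.20, "B3": 0.30},
    "mask": 3.5,
}
STDS: dict[str, Any] = {
    "image": {"B1": 0.01, "B2": 0.02, "B3": 0.03},
    "mask": 1.2,
}


def select_stats(stats: Mapping[str, Any], bands: Sequence[str]) -> dict[str, Any]:
    """Subset per-band statistics for 'image' in the order of `bands`;
    keep every other modality's statistics unchanged."""
    return {
        m: ([stats[m][band] for band in bands] if m == "image" else stats[m])
        for m in stats
    }


means = select_stats(MEANS, ["B3", "B1"])
stds = select_stats(STDS, ["B3", "B1"])
print(means["image"], stds["image"])  # [0.3, 0.1] [0.03, 0.01]
print(means["mask"], stds["mask"])    # 3.5 1.2
```

Because the image lists follow the order of bands, the statistics stay aligned with the channel order, assuming the dataset stacks its channels in that same order.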

setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either fit, validate, test, or predict.
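CarbonFluxNonGeoDataModule.setup builds the validation, test, and prediction datasets all from the "test" split. The function below transcribes the setup branches into a plain lookup so the stage-to-split mapping is easy to inspect; datasets_for_stage is a hypothetical helper for illustration, not part of TerraTorch.

```python
def datasets_for_stage(stage: str) -> dict[str, str]:
    """Map each dataset attribute that setup(stage) creates to the split it reads.

    Transcribed from CarbonFluxNonGeoDataModule.setup; unknown stages build nothing.
    """
    built: dict[str, str] = {}
    if stage in ["fit"]:
        built["train_dataset"] = "train"
    if stage in ["fit", "validate"]:
        built["val_dataset"] = "test"  # validation reuses the "test" split
    if stage in ["test"]:
        built["test_dataset"] = "test"
    if stage in ["predict"]:
        built["predict_dataset"] = "test"
    return built


for stage in ("fit", "validate", "test", "predict"):
    print(stage, datasets_for_stage(stage))
```

One consequence: validation metrics logged during fit are computed on the same split that the test stage later evaluates.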

terratorch.datamodules.forestnet

ForestNetNonGeoDataModule

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for the ForestNet dataset.

Source code in terratorch/datamodules/forestnet.py
class ForestNetNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for the ForestNet dataset."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        label_map: dict[str, int] = ForestNetNonGeo.default_label_map,
        bands: Sequence[str] = ForestNetNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        fraction: float = 1.0,
        aug: AugmentationSequential = None,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the ForestNetNonGeoDataModule.

        Args:
            data_root (str): Directory containing the dataset.
            batch_size (int, optional): Batch size for data loaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            label_map (dict[str, int], optional): Mapping of labels to integers. Defaults to ForestNetNonGeo.default_label_map.
            bands (Sequence[str], optional): List of band names to use. Defaults to ForestNetNonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction.
            fraction (float, optional): Fraction of data to use. Defaults to 1.0.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline; if None, uses Normalize.
            use_metadata (bool): Whether to return metadata info.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(ForestNetNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        self.means = [MEANS[b] for b in bands]
        self.stds = [STDS[b] for b in bands]
        self.label_map = label_map
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = Normalize(self.means, self.stds) if aug is None else aug
        self.fraction = fraction
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                label_map=self.label_map,
                transform=self.train_transform,
                bands=self.bands,
                fraction=self.fraction,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                label_map=self.label_map,
                transform=self.val_transform,
                bands=self.bands,
                fraction=self.fraction,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                label_map=self.label_map,
                transform=self.test_transform,
                bands=self.bands,
                fraction=self.fraction,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                label_map=self.label_map,
                transform=self.predict_transform,
                bands=self.bands,
                fraction=self.fraction,
                use_metadata=self.use_metadata,
            )
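label_map is passed unchanged to every ForestNetNonGeo split. A str-to-int mapping of this kind is typically used to turn string class labels into the integer targets a classification loss expects, and its inverse turns predicted indices back into names. A minimal sketch: the class names below are hypothetical and not necessarily those in ForestNetNonGeo.default_label_map, and how the dataset applies the mapping internally is not shown here.

```python
# Hypothetical label map; not ForestNetNonGeo.default_label_map.
label_map: dict[str, int] = {"class_a": 0, "class_b": 1, "class_c": 2}

# Inverse mapping for decoding model outputs back into names.
index_to_label = {index: name for name, index in label_map.items()}


def encode(labels: list[str]) -> list[int]:
    """Convert string labels into integer class indices."""
    return [label_map[name] for name in labels]


def decode(indices: list[int]) -> list[str]:
    """Convert integer class indices back into string labels."""
    return [index_to_label[i] for i in indices]


targets = encode(["class_b", "class_a", "class_b"])
print(targets)          # [1, 0, 1]
print(decode(targets))  # ['class_b', 'class_a', 'class_b']
```

If several names map to the same integer (for example to merge classes), the inverse mapping is no longer unique and keeps only the last name seen for each index.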
__init__(data_root, batch_size=4, num_workers=0, label_map=ForestNetNonGeo.default_label_map, bands=ForestNetNonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, fraction=1.0, aug=None, use_metadata=False, **kwargs)

Initializes the ForestNetNonGeoDataModule.

Parameters:
  • data_root (str) –

    Directory containing the dataset.

  • batch_size (int, default: 4 ) –

    Batch size for data loaders. Defaults to 4.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • label_map (dict[str, int], default: default_label_map ) –

    Mapping of labels to integers. Defaults to ForestNetNonGeo.default_label_map.

  • bands (Sequence[str], default: all_band_names ) –

    List of band names to use. Defaults to ForestNetNonGeo.all_band_names.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training data.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation data.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing data.

  • predict_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for prediction.

  • fraction (float, default: 1.0 ) –

    Fraction of data to use. Defaults to 1.0.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline; if None, uses Normalize.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.
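When aug is None, ForestNetNonGeoDataModule normalizes with the per-band means and standard deviations picked from MEANS and STDS in the order given by bands. Assuming Normalize performs the usual channel-wise standardization, each pixel x of band c becomes (x - mean[c]) / std[c]. The stdlib-only sketch below shows that arithmetic on a tiny image stored as a list of bands; the numbers are made up and normalize_channels is a hypothetical helper.

```python
from collections.abc import Sequence

Band = list[list[float]]  # one 2-D band as rows of pixel values


def normalize_channels(
    image: Sequence[Band], means: Sequence[float], stds: Sequence[float]
) -> list[Band]:
    """Channel-wise standardization: (x - mean[c]) / std[c] for every pixel of band c."""
    if not (len(image) == len(means) == len(stds)):
        raise ValueError("image, means and stds must have one entry per band")
    return [
        [[(x - m) / s for x in row] for row in band]
        for band, m, s in zip(image, means, stds)
    ]


# Two bands, each 2x2; statistics listed in the same band order as the image.
image = [
    [[10.0, 12.0], [14.0, 16.0]],
    [[0.5, 1.5], [2.5, 3.5]],
]
print(normalize_channels(image, means=[12.0, 2.0], stds=[2.0, 0.5]))
# [[[-1.0, 0.0], [1.0, 2.0]], [[-3.0, -1.0], [1.0, 3.0]]]
```

This is why bands and the statistics must stay aligned: changing the order of bands reorders the means and stds lists built in __init__ accordingly.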

setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either fit, validate, test, or predict.

terratorch.datamodules.fire_scars

FireScarsDataModule

Bases: GeoDataModule

Geo Fire Scars data module implementation that merges input data with ground truth segmentation masks.

Source code in terratorch/datamodules/fire_scars.py
class FireScarsDataModule(GeoDataModule):
    """Geo Fire Scars data module implementation that merges input data with ground truth segmentation masks."""

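    def __init__(self, data_root: str, **kwargs: Any) -> None:
        """
        Initializes the FireScarsDataModule.

        Args:
            data_root (str): Root directory containing the training/ and validation/ subdirectories.
            **kwargs: Additional keyword arguments forwarded to GeoDataModule.
        """
        # Positional arguments: batch_size=4, patch_size=224, length=100, num_workers=0.
        super().__init__(FireScarsSegmentationMask, 4, 224, 100, 0, **kwargs)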
        means = list(MEANS.values())
        stds = list(STDS.values())
        self.train_aug = AugmentationSequential(K.RandomCrop(224, 224), K.Normalize(means, stds))
        self.aug = AugmentationSequential(K.Normalize(means, stds))
        self.data_root = data_root

    def setup(self, stage: str) -> None:
        self.images = FireScarsHLS(
            os.path.join(self.data_root, "training/")
        )
        self.labels = FireScarsSegmentationMask(
            os.path.join(self.data_root, "training/")
        )
        self.dataset = self.images & self.labels

        self.images_test = FireScarsHLS(
            os.path.join(self.data_root, "validation/")
        )
        self.labels_test = FireScarsSegmentationMask(
            os.path.join(self.data_root, "validation/")
        )
        self.val_dataset = self.images_test & self.labels_test

        if stage in ["fit"]:
            self.train_batch_sampler = RandomBatchGeoSampler(self.dataset, self.patch_size, self.batch_size, None)
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(self.val_dataset, self.patch_size, self.patch_size)
        if stage in ["test"]:
            self.test_sampler = GridGeoSampler(self.val_dataset, self.patch_size, self.patch_size)
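For validation and testing, FireScarsDataModule passes patch_size (224) as both the size and the stride of GridGeoSampler, so the validation scenes are read as a regular grid of 224 x 224 patches, while training draws random batches with RandomBatchGeoSampler. The sketch below computes the top-left pixel offsets of such a grid for a single raster. It is a simplified, pixel-based illustration with hypothetical helper names; TorchGeo's own sampler converts between pixels and map units and may treat the last row and column differently. Here the final patch is shifted back so that it ends exactly at the border.

```python
import math


def grid_offsets(length: int, size: int, stride: int) -> list[int]:
    """Start offsets along one axis for patches of `size` taken every `stride` pixels.

    The last offset is clamped so the final patch ends at the border instead of
    running past it; rasters smaller than one patch yield a single offset of 0.
    """
    if size <= 0 or stride <= 0:
        raise ValueError("size and stride must be positive")
    if length <= size:
        return [0]
    count = math.ceil((length - size) / stride) + 1
    return [min(i * stride, length - size) for i in range(count)]


def grid_patches(height: int, width: int, size: int, stride: int) -> list[tuple[int, int]]:
    """(row, col) top-left corners of every patch in a height x width raster."""
    return [
        (row, col)
        for row in grid_offsets(height, size, stride)
        for col in grid_offsets(width, size, stride)
    ]


patches = grid_patches(height=512, width=448, size=224, stride=224)
print(len(patches))  # 6
print(patches)       # [(0, 0), (0, 224), (224, 0), (224, 224), (288, 0), (288, 224)]
```

With stride equal to size, interior patches never overlap; only the edge-aligned last patch overlaps its neighbour when the raster size is not a multiple of the patch size.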
FireScarsNonGeoDataModule

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for Fire Scars dataset.

Source code in terratorch/datamodules/fire_scars.py
class FireScarsNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Fire Scars dataset."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = FireScarsNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        drop_last: bool = True,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the FireScarsNonGeoDataModule.

        Args:
            data_root (str): Root directory of the dataset.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (Sequence[str], optional): List of band names. Defaults to FireScarsNonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction.
            drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
            no_data_replace (float | None, optional): Replacement value for missing data. Defaults to 0.
            no_label_replace (int | None, optional): Replacement value for missing labels. Defaults to -1.
            use_metadata (bool): Whether to return metadata info.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(FireScarsNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        means = [MEANS[b] for b in bands]
        stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
        self.drop_last = drop_last
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A DataLoader over the dataset for the requested split.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset for the requested split, or if that dataset has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
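_dataloader_factory enables shuffle and drop_last only for the training split, so evaluation loaders always keep their final, possibly smaller, batch. The sketch below counts batches per epoch under that same drop_last rule; num_batches is a hypothetical helper, not part of TerraTorch.

```python
import math


def num_batches(n_samples: int, batch_size: int, split: str, drop_last: bool = True) -> int:
    """Batches per epoch, applying drop_last only to the 'train' split
    (mirroring `drop_last=split == "train" and self.drop_last`)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    drop = split == "train" and drop_last
    return n_samples // batch_size if drop else math.ceil(n_samples / batch_size)


print(num_batches(10, 4, "train"))                   # 2: the last 2 samples are dropped
print(num_batches(10, 4, "val"))                     # 3: the last batch holds 2 samples
print(num_batches(10, 4, "train", drop_last=False))  # 3
```

Dropping the incomplete training batch keeps every training step at the same batch size, while evaluation keeps every sample so metrics cover the whole split.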
__init__(data_root, batch_size=4, num_workers=0, bands=FireScarsNonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, drop_last=True, no_data_replace=0, no_label_replace=-1, use_metadata=False, **kwargs)

Initializes the FireScarsNonGeoDataModule.

Parameters:
  • data_root (str) –

    Root directory of the dataset.

  • batch_size (int, default: 4 ) –

    Batch size for DataLoaders. Defaults to 4.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • bands (Sequence[str], default: all_band_names ) –

    List of band names. Defaults to FireScarsNonGeo.all_band_names.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • predict_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for prediction.

  • drop_last (bool, default: True ) –

    Whether to drop the last incomplete batch. Defaults to True.

  • no_data_replace (float | None, default: 0 ) –

    Replacement value for missing data. Defaults to 0.

  • no_label_replace (int | None, default: -1 ) –

    Replacement value for missing labels. Defaults to -1.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.

setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either fit, validate, test, or predict.

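
The stage handling in `FireScarsNonGeoDataModule.setup` can be summarized as a lookup table. This sketch restates the wiring of that method without importing TerraTorch; `datasets_for_stage` is a hypothetical helper. Note that the `test` and `predict` stages both read the `"val"` split, so test metrics are computed on the same data used for validation.

```python
# Stage -> {dataset attribute: split}, as wired in FireScarsNonGeoDataModule.setup().
FIRE_SCARS_SETUP = {
    "fit": {"train_dataset": "train", "val_dataset": "val"},
    "validate": {"val_dataset": "val"},
    "test": {"test_dataset": "val"},
    "predict": {"predict_dataset": "val"},
}


def datasets_for_stage(stage: str) -> dict[str, str]:
    """Return the dataset attributes setup(stage) would create and the split each one reads.

    Like the real method, an unrecognized stage creates nothing and raises no error.
    """
    return dict(FIRE_SCARS_SETUP.get(stage, {}))


for stage in ("fit", "validate", "test", "predict"):
    print(stage, datasets_for_stage(stage))
```
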

terratorch.datamodules.landslide4sense

Landslide4SenseNonGeoDataModule

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for Landslide4Sense dataset.

Source code in terratorch/datamodules/landslide4sense.py
class Landslide4SenseNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Landslide4Sense dataset."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = Landslide4SenseNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the Landslide4SenseNonGeoDataModule.

        Args:
            data_root (str): Root directory of the dataset.
            batch_size (int, optional): Batch size for data loaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (Sequence[str], optional): List of band names to use. Defaults to Landslide4SenseNonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            aug (AugmentationSequential, optional): Augmentation pipeline; if None, applies normalization using computed means and stds.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(Landslide4SenseNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        self.means = [MEANS[b] for b in bands]
        self.stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = (
            AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
        )

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands
            )
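
All four transform arguments pass through `wrap_in_compose_is_list`. Judging from its name and the `A.Compose | None | list[A.BasicTransform]` type hints, it turns a plain list of transforms into a single composed pipeline. The sketch below assumes that behavior. It uses a toy `Compose` class and a hypothetical `wrap_if_list` helper instead of Albumentations, so it illustrates only the idea; the real function may differ in details.

```python
from typing import Any, Callable

Sample = dict[str, Any]
Transform = Callable[..., Sample]


class Compose:
    """Toy stand-in for A.Compose: calls each transform in order with keyword arguments."""

    def __init__(self, transforms: list[Transform]) -> None:
        self.transforms = list(transforms)

    def __call__(self, **sample: Any) -> Sample:
        for transform in self.transforms:
            sample = transform(**sample)
        return sample


def wrap_if_list(transform: Any) -> Any:
    """Hypothetical analogue of wrap_in_compose_is_list: wrap lists, pass anything else through unchanged."""
    return Compose(transform) if isinstance(transform, list) else transform


def add_one(**sample: Any) -> Sample:
    return {**sample, "image": [x + 1 for x in sample["image"]]}


def double(**sample: Any) -> Sample:
    return {**sample, "image": [x * 2 for x in sample["image"]]}


pipeline = wrap_if_list([add_one, double])
print(pipeline(image=[1, 2], mask=[0, 1]))  # add_one runs before double; mask is untouched
print(wrap_if_list(None))  # None passes through, meaning "no transform"
```
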
__init__(data_root, batch_size=4, num_workers=0, bands=Landslide4SenseNonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, aug=None, **kwargs)

Initializes the Landslide4SenseNonGeoDataModule.

Parameters:
  • data_root (str) –

    Root directory of the dataset.

  • batch_size (int, default: 4 ) –

    Batch size for data loaders. Defaults to 4.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • bands (Sequence[str], default: all_band_names ) –

    List of band names to use. Defaults to Landslide4SenseNonGeo.all_band_names.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training data.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation data.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing data.

  • predict_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for prediction data.

  • aug (AugmentationSequential, default: None ) –

    Augmentation pipeline; if None, applies normalization using computed means and stds.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.

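
The `aug` argument follows a default-or-override pattern. When it is `None`, the module builds a normalization step from the per-band means and standard deviations; otherwise the supplied pipeline is used as-is. As a result, a custom `aug` replaces the default normalization rather than extending it. The pure-Python sketch below mirrors that logic. `make_normalize`, `resolve_aug`, and `keep_as_is` are hypothetical stand-ins for Kornia's `K.Normalize` inside `AugmentationSequential`.

```python
from typing import Callable, Optional, Sequence

Batch = dict[str, list[list[float]]]
Aug = Callable[[Batch], Batch]


def make_normalize(means: Sequence[float], stds: Sequence[float]) -> Aug:
    """Build a per-band (x - mean) / std step that touches only batch["image"], like data_keys=["image"]."""

    def normalize(batch: Batch) -> Batch:
        image = [[(x - m) / s for x in band] for band, m, s in zip(batch["image"], means, stds)]
        return {**batch, "image": image}

    return normalize


def resolve_aug(aug: Optional[Aug], means: Sequence[float], stds: Sequence[float]) -> Aug:
    """Mirror `self.aug = <normalize> if aug is None else aug`: a user pipeline replaces the default outright."""
    return make_normalize(means, stds) if aug is None else aug


def keep_as_is(batch: Batch) -> Batch:
    """A custom pipeline that, unlike the default, does not normalize."""
    return batch


batch = {"image": [[2.0, 4.0]], "mask": [[1.0, 0.0]]}
print(resolve_aug(None, means=[2.0], stds=[2.0])(batch))  # image normalized, mask untouched
print(resolve_aug(keep_as_is, means=[2.0], stds=[2.0])(batch))  # image left unnormalized
```

If you pass your own `aug`, include a normalization step in it unless your inputs are already normalized.
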
setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either fit, validate, test, or predict.


terratorch.datamodules.m_eurosat

MEuroSATNonGeoDataModule

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-EuroSAT dataset.

Source code in terratorch/datamodules/m_eurosat.py
class MEuroSATNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-EuroSAT dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MEuroSATNonGeoDataModule for the MEuroSATNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MEuroSATNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )

__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs)

Initializes the MEuroSATNonGeoDataModule for the MEuroSATNonGeo dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • bands (Sequence[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline. Defaults to None.

  • partition (str, default: 'default' ) –

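    Name of the GEO-Bench dataset partition to load; partitions differ in the fraction of the training set they include. Defaults to "default".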

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.

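
`MEuroSATNonGeoDataModule` and the M-BigEarthNet, M-BrickKiln and M-ForestNet modules in this section are thin subclasses. Each one passes its dataset class and its module-level `MEANS`/`STDS` to `GeobenchDataModule` and forwards every other argument by keyword. M-ForestNet additionally forwards `use_metadata`. The toy classes below are hypothetical and show the pattern in isolation.

```python
from typing import Any


class ToyBenchDataModule:
    """Toy stand-in for GeobenchDataModule: keeps the dataset class, band statistics, and other options."""

    def __init__(self, dataset_class: type, means: dict[str, float], stds: dict[str, float], **kwargs: Any) -> None:
        self.dataset_class = dataset_class
        self.means = means
        self.stds = stds
        self.options = kwargs


class ToyEuroSAT:
    """Placeholder for a dataset class such as MEuroSATNonGeo."""


TOY_MEANS = {"RED": 0.1}  # hypothetical values
TOY_STDS = {"RED": 0.02}


class ToyEuroSATDataModule(ToyBenchDataModule):
    """Pins the dataset class and statistics, then forwards everything else, like MEuroSATNonGeoDataModule."""

    def __init__(self, batch_size: int = 8, partition: str = "default", **kwargs: Any) -> None:
        super().__init__(ToyEuroSAT, TOY_MEANS, TOY_STDS, batch_size=batch_size, partition=partition, **kwargs)


dm = ToyEuroSATDataModule(batch_size=16)
print(dm.dataset_class.__name__, dm.options)
```

Because the subclasses add almost no logic, their constructor arguments behave the same way across datasets. Mainly the band names and statistics differ.
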

terratorch.datamodules.m_bigearthnet

MBigEarthNonGeoDataModule

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-BigEarthNet dataset.

Source code in terratorch/datamodules/m_bigearthnet.py
class MBigEarthNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-BigEarthNet dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MBigEarthNonGeoDataModule for the M-BigEarthNet dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MBigEarthNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )

__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs)

Initializes the MBigEarthNonGeoDataModule for the M-BigEarthNet dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • bands (Sequence[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline. Defaults to None.

  • partition (str, default: 'default' ) –

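    Name of the GEO-Bench dataset partition to load; partitions differ in the fraction of the training set they include. Defaults to "default".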

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.


terratorch.datamodules.m_brick_kiln

MBrickKilnNonGeoDataModule

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-BrickKiln dataset.

Source code in terratorch/datamodules/m_brick_kiln.py
class MBrickKilnNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-BrickKiln dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MBrickKilnNonGeoDataModule for the M-BrickKilnNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MBrickKilnNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )

__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs)

Initializes the MBrickKilnNonGeoDataModule for the M-BrickKiln dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • bands (Sequence[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline. Defaults to None.

  • partition (str, default: 'default' ) –

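    Name of the GEO-Bench dataset partition to load; partitions differ in the fraction of the training set they include. Defaults to "default".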

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.


terratorch.datamodules.m_forestnet

MForestNetNonGeoDataModule

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-ForestNet dataset.

Source code in terratorch/datamodules/m_forestnet.py
class MForestNetNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-ForestNet dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        use_metadata: bool = False,  # noqa: FBT002, FBT001
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MForestNetNonGeoDataModule for the MForestNetNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            use_metadata (bool): Whether to return metadata info.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MForestNetNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            use_metadata=use_metadata,
            **kwargs,
        )

__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', use_metadata=False, **kwargs)

Initializes the MForestNetNonGeoDataModule for the MForestNetNonGeo dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • bands (Sequence[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline. Defaults to None.

  • partition (str, default: 'default' ) –

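    Name of the GEO-Bench dataset partition to load; partitions differ in the fraction of the training set they include. Defaults to "default".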

  • use_metadata (bool, default: False ) –

    Whether to return metadata info.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.


terratorch.datamodules.m_so2sat

MSo2SatNonGeoDataModule

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-So2Sat dataset.

Source code in terratorch/datamodules/m_so2sat.py
class MSo2SatNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-So2Sat dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MSo2SatNonGeoDataModule for the MSo2SatNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MSo2SatNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )
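
The source above shows that MSo2SatNonGeoDataModule contains no logic of its own. Like the other GEO-Bench modules in this section, it binds a single dataset class and the module-level MEANS and STDS to GeobenchDataModule and passes every other argument through unchanged. The sketch below reproduces that delegation pattern in simplified form. It uses hypothetical stand-ins (ToyDataset, BaseDataModule, ToyDataModule, build_dataset) that are not TerraTorch or TorchGeo APIs. It also assumes, in simplified form, that unrecognized keyword arguments are stored and later forwarded to the dataset constructor.

```python
from __future__ import annotations

from typing import Any


class ToyDataset:
    """Hypothetical stand-in for a dataset class such as MSo2SatNonGeo."""

    def __init__(self, data_root: str, bands: list[str] | None = None, **kwargs: Any) -> None:
        self.data_root = data_root
        self.bands = bands
        self.extra = kwargs  # anything the dataset does not name explicitly


class BaseDataModule:
    """Hypothetical stand-in for GeobenchDataModule (simplified)."""

    def __init__(self, dataset_class: type, means: dict[str, float], stds: dict[str, float],
                 batch_size: int = 8, num_workers: int = 0, **kwargs: Any) -> None:
        self.dataset_class = dataset_class
        self.means = means
        self.stds = stds
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.kwargs = kwargs  # assumed to be forwarded to the dataset constructor

    def build_dataset(self) -> Any:
        return self.dataset_class(**self.kwargs)


# Module-level statistics, analogous to MEANS / STDS (values made up).
MEANS = {"RED": 0.1, "GREEN": 0.2}
STDS = {"RED": 0.5, "GREEN": 0.4}


class ToyDataModule(BaseDataModule):
    """Binds one dataset class and its statistics; all else passes through."""

    def __init__(self, batch_size: int = 8, num_workers: int = 0,
                 data_root: str = "./", **kwargs: Any) -> None:
        super().__init__(ToyDataset, MEANS, STDS, batch_size=batch_size,
                         num_workers=num_workers, data_root=data_root, **kwargs)
```

With this design, a new benchmark only needs a dataset class and its statistics. For example, ToyDataModule(batch_size=16, data_root="/data", partition="default").build_dataset() yields a ToyDataset whose extra arguments contain only partition.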
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs)

Initializes the MSo2SatNonGeoDataModule for the MSo2SatNonGeo dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • bands (Sequence[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline. Defaults to None.

  • partition (str, default: 'default' ) –

    Partition size. Defaults to "default".

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.

terratorch.datamodules.m_pv4ger

MPv4gerNonGeoDataModule

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-Pv4ger dataset.

Source code in terratorch/datamodules/m_pv4ger.py
class MPv4gerNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-Pv4ger dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        use_metadata: bool = False,  # noqa: FBT002, FBT001
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MPv4gerNonGeoDataModule for the MPv4gerNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            use_metadata (bool): Whether to return metadata info.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MPv4gerNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            use_metadata=use_metadata,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', use_metadata=False, **kwargs)

Initializes the MPv4gerNonGeoDataModule for the MPv4gerNonGeo dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • bands (Sequence[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline. Defaults to None.

  • partition (str, default: 'default' ) –

    Partition size. Defaults to "default".

  • use_metadata (bool, default: False ) –

    Whether to return metadata info.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.

terratorch.datamodules.m_cashew_plantation

MBeninSmallHolderCashewsNonGeoDataModule

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-Cashew Plantation dataset.

Source code in terratorch/datamodules/m_cashew_plantation.py
class MBeninSmallHolderCashewsNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-Cashew Plantation dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        use_metadata: bool = False,  # noqa: FBT002, FBT001
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MBeninSmallHolderCashewsNonGeoDataModule for the M-BeninSmallHolderCashewsNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            use_metadata (bool): Whether to return metadata info.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MBeninSmallHolderCashewsNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            use_metadata=use_metadata,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', use_metadata=False, **kwargs)

Initializes the MBeninSmallHolderCashewsNonGeoDataModule for the M-BeninSmallHolderCashewsNonGeo dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • bands (Sequence[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline. Defaults to None.

  • partition (str, default: 'default' ) –

    Partition size. Defaults to "default".

  • use_metadata (bool, default: False ) –

    Whether to return metadata info.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.
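
The train_transform, val_transform and test_transform parameters accept three forms: an A.Compose, None, or a plain list of transforms. A simple way to support all three is to wrap a list into a composed pipeline and pass the other two forms through unchanged. How TerraTorch does this internally is not shown in this section, so the sketch below is only an illustration of that idea. It uses a minimal hypothetical Compose stand-in, so it runs without albumentations installed, and mimics the albumentations calling style of keyword arguments in and a dict out. as_pipeline, add_one and double are illustrative names, not library functions.

```python
from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any

Transform = Callable[..., dict[str, Any]]


class Compose:
    """Minimal stand-in for albumentations.Compose: applies transforms in order."""

    def __init__(self, transforms: Sequence[Transform]) -> None:
        self.transforms = list(transforms)

    def __call__(self, **data: Any) -> dict[str, Any]:
        for t in self.transforms:
            data = t(**data)
        return data


def as_pipeline(transform: Compose | Sequence[Transform] | None) -> Compose | None:
    """Normalize the three accepted forms into a Compose (or keep None)."""
    if transform is None or isinstance(transform, Compose):
        return transform
    return Compose(transform)


def add_one(**data: Any) -> dict[str, Any]:
    """Example transform: add 1 to every value of the 'image' entry."""
    return {**data, "image": [v + 1 for v in data["image"]]}


def double(**data: Any) -> dict[str, Any]:
    """Example transform: multiply every value of the 'image' entry by 2."""
    return {**data, "image": [v * 2 for v in data["image"]]}
```

Here as_pipeline([add_one, double])(image=[1, 2]) applies add_one first and then double. Other keys, such as a mask, pass through untouched.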

terratorch.datamodules.m_nz_cattle

MNzCattleNonGeoDataModule

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-NZCattle dataset.

Source code in terratorch/datamodules/m_nz_cattle.py
class MNzCattleNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-NZCattle dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        use_metadata: bool = False,  # noqa: FBT002, FBT001
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MNzCattleNonGeoDataModule for the MNzCattleNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            use_metadata (bool): Whether to return metadata info.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MNzCattleNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            use_metadata=use_metadata,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', use_metadata=False, **kwargs)

Initializes the MNzCattleNonGeoDataModule for the MNzCattleNonGeo dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • bands (Sequence[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline. Defaults to None.

  • partition (str, default: 'default' ) –

    Partition size. Defaults to "default".

  • use_metadata (bool, default: False ) –

    Whether to return metadata info.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.

terratorch.datamodules.m_chesapeake_landcover

MChesapeakeLandcoverNonGeoDataModule

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-ChesapeakeLandcover dataset.

Source code in terratorch/datamodules/m_chesapeake_landcover.py
class MChesapeakeLandcoverNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-ChesapeakeLandcover dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MChesapeakeLandcoverNonGeoDataModule for the M-ChesapeakeLandcover dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MChesapeakeLandcoverNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs)

Initializes the MChesapeakeLandcoverNonGeoDataModule for the M-ChesapeakeLandcover dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • bands (Sequence[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline. Defaults to None.

  • partition (str, default: 'default' ) –

    Partition size. Defaults to "default".

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.

terratorch.datamodules.m_pv4ger_seg

MPv4gerSegNonGeoDataModule

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-Pv4gerSeg dataset.

Source code in terratorch/datamodules/m_pv4ger_seg.py
class MPv4gerSegNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-Pv4gerSeg dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        use_metadata: bool = False,  # noqa: FBT002, FBT001
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MPv4gerSegNonGeoDataModule for the MPv4gerSegNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            use_metadata (bool): Whether to return metadata info.
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MPv4gerSegNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            use_metadata=use_metadata,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', use_metadata=False, **kwargs)

Initializes the MPv4gerSegNonGeoDataModule for the MPv4gerSegNonGeo dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • bands (Sequence[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline. Defaults to None.

  • partition (str, default: 'default' ) –

    Name of the GEO-Bench partition to load (partitions differ in the size of the training split). Defaults to "default".

  • use_metadata (bool, default: False ) –

    Whether to return metadata information. Defaults to False.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.
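The GEO-Bench datamodules on this page (MPv4ger, M-SA-CropType, M-NeonTree) share one pattern. The subclass fixes the dataset class and its per-band `MEANS`/`STDS`, then forwards every user-facing argument to `GeobenchDataModule`. The sketch below reproduces only that forwarding shape. It uses a hypothetical, simplified base class (`SimpleGeobenchBase`) and placeholder names (`ToyDataset`, `ToyDataModule`, `TOY_MEANS`, `TOY_STDS`). The real `GeobenchDataModule` does more than store its arguments.

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import Any


class SimpleGeobenchBase:
    """Hypothetical, simplified stand-in for a generic GEO-Bench datamodule base."""

    def __init__(self, dataset_class: type, means: dict[str, float], stds: dict[str, float],
                 batch_size: int = 8, num_workers: int = 0,
                 bands: Sequence[str] | None = None, partition: str = "default",
                 **kwargs: Any) -> None:
        # In this sketch the base only records its configuration.
        self.dataset_class = dataset_class
        self.means = means
        self.stds = stds
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.bands = bands
        self.partition = partition
        self.kwargs = kwargs


class ToyDataset:
    """Placeholder for a dataset class such as MPv4gerSegNonGeo."""


TOY_MEANS = {"RED": 0.1, "GREEN": 0.2}  # hypothetical statistics
TOY_STDS = {"RED": 0.05, "GREEN": 0.07}


class ToyDataModule(SimpleGeobenchBase):
    """Same shape as the subclasses above: bind dataset class and stats, forward the rest."""

    def __init__(self, batch_size: int = 8, num_workers: int = 0,
                 bands: Sequence[str] | None = None, partition: str = "default",
                 **kwargs: Any) -> None:
        super().__init__(ToyDataset, TOY_MEANS, TOY_STDS, batch_size=batch_size,
                         num_workers=num_workers, bands=bands, partition=partition, **kwargs)


dm = ToyDataModule(batch_size=16, bands=["RED"], use_metadata=True)
print(dm.dataset_class.__name__, dm.batch_size, dm.kwargs)  # ToyDataset 16 {'use_metadata': True}
```

Because the subclass adds no logic of its own, any unrecognized keyword (here `use_metadata`) simply travels to the base through `**kwargs`.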

terratorch.datamodules.m_SA_crop_type

MSACropTypeNonGeoDataModule

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-SA-CropType dataset.

Source code in terratorch/datamodules/m_SA_crop_type.py
class MSACropTypeNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-SA-CropType dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MSACropTypeNonGeoDataModule for the MSACropTypeNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MSACropTypeNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs)

Initializes the MSACropTypeNonGeoDataModule for the MSACropTypeNonGeo dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • bands (Sequence[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline. Defaults to None.

  • partition (str, default: 'default' ) –

    Name of the GEO-Bench partition to load (partitions differ in the size of the training split). Defaults to "default".

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.
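The `train_transform`, `val_transform` and `test_transform` arguments are typed as an albumentations `Compose`, `None`, or a plain list of transforms. Other datamodules on this page pass such arguments through `wrap_in_compose_is_list`. Assume that a helper of that kind composes a list and passes anything else through unchanged. Under that assumption, the behavior looks like the sketch below. `Compose` here is a minimal pure-Python stand-in, not albumentations. `wrap_if_list`, `add_one` and `double` are hypothetical names.

```python
from collections.abc import Callable, Sequence
from typing import Any

Transform = Callable[[dict[str, Any]], dict[str, Any]]


class Compose:
    """Minimal stand-in for albumentations.Compose: apply transforms in order."""

    def __init__(self, transforms: Sequence[Transform]) -> None:
        self.transforms = list(transforms)

    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
        for t in self.transforms:
            sample = t(sample)
        return sample


def wrap_if_list(transform):
    """Hypothetical helper: compose a list; pass a composed transform or None through."""
    if isinstance(transform, list):
        return Compose(transform)
    return transform


def add_one(sample):
    # Toy transform: add 1 to every pixel value.
    return {**sample, "image": [v + 1 for v in sample["image"]]}


def double(sample):
    # Toy transform: multiply every pixel value by 2.
    return {**sample, "image": [v * 2 for v in sample["image"]]}


pipeline = wrap_if_list([add_one, double])
print(pipeline({"image": [1, 2]})["image"])  # [4, 6]
print(wrap_if_list(None))  # None
```

The order of the list matters: `add_one` runs before `double`, so 1 becomes (1 + 1) * 2 = 4.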

terratorch.datamodules.m_neontree

MNeonTreeNonGeoDataModule

Bases: GeobenchDataModule

NonGeo LightningDataModule implementation for M-NeonTree dataset.

Source code in terratorch/datamodules/m_neontree.py
class MNeonTreeNonGeoDataModule(GeobenchDataModule):
    """NonGeo LightningDataModule implementation for M-NeonTree dataset."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        bands: Sequence[str] | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        partition: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MNeonTreeNonGeoDataModule for the MNeonTreeNonGeo dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            bands (Sequence[str] | None, optional): List of bands to use. Defaults to None.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing.
            aug (AugmentationSequential, optional): Augmentation/normalization pipeline. Defaults to None.
            partition (str, optional): Partition size. Defaults to "default".
            **kwargs (Any): Additional keyword arguments.
        """
        super().__init__(
            MNeonTreeNonGeo,
            MEANS,
            STDS,
            batch_size=batch_size,
            num_workers=num_workers,
            data_root=data_root,
            bands=bands,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            aug=aug,
            partition=partition,
            **kwargs,
        )
__init__(batch_size=8, num_workers=0, data_root='./', bands=None, train_transform=None, val_transform=None, test_transform=None, aug=None, partition='default', **kwargs)

Initializes the MNeonTreeNonGeoDataModule for the MNeonTreeNonGeo dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • bands (Sequence[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing.

  • aug (AugmentationSequential, default: None ) –

    Augmentation/normalization pipeline. Defaults to None.

  • partition (str, default: 'default' ) –

    Name of the GEO-Bench partition to load (partitions differ in the size of the training split). Defaults to "default".

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.
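Each of these datamodules receives per-band `MEANS` and `STDS`, and `bands` chooses which channels are loaded. The MultiTemporalCropClassificationDataModule source on this page shows one way to combine the two. It looks up `[MEANS[b] for b in bands]` and `[STDS[b] for b in bands]` and passes them to `Normalize`. The NumPy sketch below performs the same per-band standardization on a `(channels, h, w)` array. The band names, the statistics and the `normalize` helper are made up for illustration.

```python
import numpy as np

# Hypothetical per-band statistics keyed by band name (real values differ per dataset).
MEANS = {"BLUE": 500.0, "GREEN": 700.0, "RED": 900.0, "NIR": 2500.0}
STDS = {"BLUE": 100.0, "GREEN": 150.0, "RED": 200.0, "NIR": 500.0}


def normalize(image: np.ndarray, bands: list[str]) -> np.ndarray:
    """Standardize a (channels, h, w) image whose channel order follows `bands`."""
    # Select the statistics in the same order as the requested bands,
    # then reshape to (channels, 1, 1) so they broadcast over h and w.
    means = np.array([MEANS[b] for b in bands], dtype=np.float32)[:, None, None]
    stds = np.array([STDS[b] for b in bands], dtype=np.float32)[:, None, None]
    return (image - means) / stds


bands = ["RED", "NIR"]
image = np.stack([np.full((2, 2), 1100.0), np.full((2, 2), 2000.0)])
out = normalize(image, bands)
print(out[:, 0, 0])  # [ 1. -1.]
```

Selecting statistics by band name, rather than by position, keeps normalization correct when `bands` is a reordered subset of the full band list.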

terratorch.datamodules.multi_temporal_crop_classification

MultiTemporalCropClassificationDataModule

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for multi-temporal crop classification.

Source code in terratorch/datamodules/multi_temporal_crop_classification.py
class MultiTemporalCropClassificationDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for multi-temporal crop classification."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = MultiTemporalCropClassification.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        drop_last: bool = True,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        expand_temporal_dimension: bool = True,
        reduce_zero_label: bool = True,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the MultiTemporalCropClassificationDataModule for multi-temporal crop classification.

        Args:
            data_root (str): Directory containing the dataset.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (Sequence[str], optional): List of bands to use. Defaults to MultiTemporalCropClassification.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            drop_last (bool, optional): Whether to drop the last incomplete batch during training. Defaults to True.
            no_data_replace (float | None, optional): Replacement value for missing data. Defaults to 0.
            no_label_replace (int | None, optional): Replacement value for missing labels. Defaults to -1.
            expand_temporal_dimension (bool, optional): Go from shape (time*channels, h, w) to (channels, time, h, w).
                Defaults to True.
            reduce_zero_label (bool, optional): Subtract 1 from all labels. Useful when labels start from 1 instead of the
                expected 0. Defaults to True.
            use_metadata (bool): Whether to return metadata info (time and location).
            **kwargs: Additional keyword arguments.
        """
        super().__init__(MultiTemporalCropClassification, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        self.means = [MEANS[b] for b in bands]
        self.stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = Normalize(self.means, self.stds)
        self.drop_last = drop_last
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.expand_temporal_dimension = expand_temporal_dimension
        self.reduce_zero_label = reduce_zero_label
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension = self.expand_temporal_dimension,
                reduce_zero_label = self.reduce_zero_label,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension = self.expand_temporal_dimension,
                reduce_zero_label = self.reduce_zero_label,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension = self.expand_temporal_dimension,
                reduce_zero_label = self.reduce_zero_label,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                expand_temporal_dimension = self.expand_temporal_dimension,
                reduce_zero_label = self.reduce_zero_label,
                use_metadata=self.use_metadata,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
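In `_dataloader_factory`, only the training loader shuffles. The last incomplete batch is dropped only for the `"train"` split, and only when `drop_last` is true. The sketch below shows the effect on the number of batches per epoch. The helper `num_batches` is illustrative and not part of TerraTorch.

```python
import math


def num_batches(num_samples: int, batch_size: int, split: str, drop_last: bool = True) -> int:
    """Batches per epoch under the DataLoader settings used in _dataloader_factory."""
    # Mirrors `drop_last=split == "train" and self.drop_last`.
    effective_drop_last = split == "train" and drop_last
    if effective_drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


for split in ("train", "val", "test"):
    print(split, num_batches(10, 4, split))
# train 2
# val 3
# test 3
```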
__init__(data_root, batch_size=4, num_workers=0, bands=MultiTemporalCropClassification.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, drop_last=True, no_data_replace=0, no_label_replace=-1, expand_temporal_dimension=True, reduce_zero_label=True, use_metadata=False, **kwargs)

Initializes the MultiTemporalCropClassificationDataModule for multi-temporal crop classification.

Parameters:
  • data_root (str) –

    Directory containing the dataset.

  • batch_size (int, default: 4 ) –

    Batch size for DataLoaders. Defaults to 4.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • bands (Sequence[str], default: all_band_names ) –

    List of bands to use. Defaults to MultiTemporalCropClassification.all_band_names.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training data.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation data.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing data.

  • predict_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for prediction data.

  • drop_last (bool, default: True ) –

    Whether to drop the last incomplete batch during training. Defaults to True.

  • no_data_replace (float | None, default: 0 ) –

    Replacement value for missing data. Defaults to 0.

  • no_label_replace (int | None, default: -1 ) –

    Replacement value for missing labels. Defaults to -1.

  • expand_temporal_dimension (bool, default: True ) –

    Reshape the image from (time*channels, h, w) to (channels, time, h, w). Defaults to True.

  • reduce_zero_label (bool, default: True ) –

    Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to True.

  • use_metadata (bool, default: False ) –

    Whether to return metadata (time and location). Defaults to False.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.
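Two of the flags above reshape the data before it reaches the model. The NumPy sketch below illustrates them under stated assumptions:

- `expand_temporal` assumes the stacked channel axis is time-major: all bands of the first time step, then all bands of the second, and so on. This may not match the on-disk layout of every dataset.
- `reduce_zero_label` only shows the documented shift of labels by one. It does not model how `no_label_replace` interacts with that shift.

Both helper names, and the array sizes used below, are hypothetical.

```python
import numpy as np


def expand_temporal(image: np.ndarray, num_bands: int) -> np.ndarray:
    """(time*channels, h, w) -> (channels, time, h, w).

    Assumes a time-major stacked axis: index t * num_bands + c holds band c of step t.
    """
    tc, h, w = image.shape
    num_steps = tc // num_bands
    # Split the stacked axis into (time, channels), then swap them.
    return image.reshape(num_steps, num_bands, h, w).transpose(1, 0, 2, 3)


def reduce_zero_label(mask: np.ndarray) -> np.ndarray:
    """Shift labels 1..K down to 0..K-1."""
    return mask - 1


# Illustrative sizes: 3 time steps x 6 bands stacked into 18 channels.
stacked = np.arange(18 * 2 * 2, dtype=np.float32).reshape(18, 2, 2)
cube = expand_temporal(stacked, num_bands=6)
print(cube.shape)  # (6, 3, 2, 2)

mask = np.array([[1, 2], [3, 1]])
print(reduce_zero_label(mask).tolist())  # [[0, 1], [2, 0]]
```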

setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either fit, validate, test, or predict.
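The `setup` method in the MultiTemporalCropClassificationDataModule class source builds different datasets for each stage. Note that it reads the `"val"` split for the `test` and `predict` stages as well. The dictionary below transcribes that mapping from the source. The names `SETUP_PLAN` and `datasets_for` are illustrative.

```python
# attribute set by setup() -> split passed to the dataset, per stage
SETUP_PLAN = {
    "fit": {"train_dataset": "train", "val_dataset": "val"},
    "validate": {"val_dataset": "val"},
    "test": {"test_dataset": "val"},
    "predict": {"predict_dataset": "val"},
}


def datasets_for(stage: str) -> dict[str, str]:
    """Datasets built for a stage; any other stage builds nothing, as in the source."""
    return SETUP_PLAN.get(stage, {})


print(datasets_for("test"))  # {'test_dataset': 'val'}
```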

terratorch.datamodules.open_sentinel_map

OpenSentinelMapDataModule

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for Open Sentinel Map.

Source code in terratorch/datamodules/open_sentinel_map.py
class OpenSentinelMapDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Open Sentinel Map."""

    def __init__(
        self,
        bands: list[str] | None = None,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT001, FBT002
        pad_image: int | None = None,
        truncate_image: int | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the OpenSentinelMapDataModule for the Open Sentinel Map dataset.

        Args:
            bands (list[str] | None, optional): List of bands to use. Defaults to None.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            spatial_interpolate_and_stack_temporally (bool, optional): If True, the bands are interpolated and concatenated over time.
                Default is True.
            pad_image (int | None, optional): Number of timesteps to pad the time dimension of the image.
                If None, no padding is applied.
            truncate_image (int | None, optional): Number of timesteps to truncate the time dimension of the image.
                If None, no truncation is performed.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(
            OpenSentinelMap,
            batch_size=batch_size,
            num_workers=num_workers,
            **kwargs,
        )
        self.bands = bands
        self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally
        self.pad_image = pad_image
        self.truncate_image = truncate_image
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.data_root = data_root
        self.kwargs = kwargs

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = OpenSentinelMap(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
                pad_image = self.pad_image,
                truncate_image = self.truncate_image,
                **self.kwargs,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = OpenSentinelMap(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
                pad_image = self.pad_image,
                truncate_image = self.truncate_image,
                **self.kwargs,
            )
        if stage in ["test"]:
            self.test_dataset = OpenSentinelMap(
                split="test",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
                pad_image = self.pad_image,
                truncate_image = self.truncate_image,
                **self.kwargs,
            )
        if stage in ["predict"]:
            self.predict_dataset = OpenSentinelMap(
                split="test",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
                pad_image = self.pad_image,
                truncate_image = self.truncate_image,
                **self.kwargs,
            )
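The `setup` logic above decides which datasets are created for each Lightning stage. The following pure-Python sketch reproduces only that branching, without importing TerraTorch. `datasets_built` is a hypothetical helper, not part of the library. It makes explicit that `fit` builds both the train and validation datasets, and that `predict` reuses the `test` split.

```python
def datasets_built(stage: str) -> list[tuple[str, str]]:
    """Return the (attribute, split) pairs that setup(stage) above would create.

    Hypothetical helper: it mirrors the if-branches of
    OpenSentinelMapDataModule.setup but does not build any dataset.
    """
    built = []
    if stage in ["fit"]:
        built.append(("train_dataset", "train"))
    if stage in ["fit", "validate"]:
        built.append(("val_dataset", "val"))
    if stage in ["test"]:
        built.append(("test_dataset", "test"))
    if stage in ["predict"]:
        # The predict dataset is read from the "test" split.
        built.append(("predict_dataset", "test"))
    return built


for stage in ["fit", "validate", "test", "predict"]:
    print(stage, "->", datasets_built(stage))
```

An unrecognised stage simply builds nothing, which matches the behaviour of the original `setup`.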
__init__(bands=None, batch_size=8, num_workers=0, data_root='./', train_transform=None, val_transform=None, test_transform=None, predict_transform=None, spatial_interpolate_and_stack_temporally=True, pad_image=None, truncate_image=None, **kwargs)

Initializes the OpenSentinelMapDataModule for the Open Sentinel Map dataset.

Parameters:
  • bands (list[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training data.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation data.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing data.

  • predict_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for prediction data.

  • spatial_interpolate_and_stack_temporally (bool, default: True ) –

    If True, the bands are interpolated and concatenated over time. Default is True.

  • pad_image (int | None, default: None ) –

    Number of timesteps to pad the time dimension of the image. If None, no padding is applied.

  • truncate_image (int | None, default: None ) –

    Number of timesteps to truncate the time dimension of the image. If None, no truncation is performed.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.
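`pad_image` and `truncate_image` both act on the time dimension of a sample. The sketch below illustrates that idea on a `(T, C, H, W)` NumPy array. It rests on several assumptions that are not taken from the `OpenSentinelMap` source: truncation keeps the first `truncate_image` timesteps, padding appends all-zero frames at the end, and truncation runs before padding. `pad_or_truncate_time` is a hypothetical name.

```python
from typing import Optional

import numpy as np


def pad_or_truncate_time(
    image: np.ndarray,
    pad_image: Optional[int] = None,
    truncate_image: Optional[int] = None,
) -> np.ndarray:
    """Adjust the leading time axis of a (T, C, H, W) array.

    Assumptions: truncation keeps the first `truncate_image` timesteps;
    padding appends zero-valued frames until there are `pad_image` timesteps.
    """
    if truncate_image is not None:
        image = image[:truncate_image]
    if pad_image is not None and image.shape[0] < pad_image:
        missing = pad_image - image.shape[0]
        zeros = np.zeros((missing, *image.shape[1:]), dtype=image.dtype)
        image = np.concatenate([image, zeros], axis=0)
    return image


series = np.ones((5, 3, 4, 4), dtype=np.float32)  # 5 timesteps, 3 bands, 4x4 pixels
print(pad_or_truncate_time(series, truncate_image=2).shape)  # (2, 3, 4, 4)
print(pad_or_truncate_time(series, pad_image=8).shape)       # (8, 3, 4, 4)
```

Fixing the number of timesteps like this is what allows samples with different acquisition counts to be stacked into a single batch tensor.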

Source code in terratorch/datamodules/open_sentinel_map.py
def __init__(
    self,
    bands: list[str] | None = None,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    spatial_interpolate_and_stack_temporally: bool = True,  # noqa: FBT001, FBT002
    pad_image: int | None = None,
    truncate_image: int | None = None,
    **kwargs: Any,
) -> None:
    """
    Initializes the OpenSentinelMapDataModule for the Open Sentinel Map dataset.

    Args:
        bands (list[str] | None, optional): List of bands to use. Defaults to None.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        spatial_interpolate_and_stack_temporally (bool, optional): If True, the bands are interpolated and concatenated over time.
            Default is True.
        pad_image (int | None, optional): Number of timesteps to pad the time dimension of the image.
            If None, no padding is applied.
        truncate_image (int | None, optional): Number of timesteps to truncate the time dimension of the image.
            If None, no truncation is performed.
        **kwargs: Additional keyword arguments.
    """
    super().__init__(
        OpenSentinelMap,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs,
    )
    self.bands = bands
    self.spatial_interpolate_and_stack_temporally = spatial_interpolate_and_stack_temporally
    self.pad_image = pad_image
    self.truncate_image = truncate_image
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.data_root = data_root
    self.kwargs = kwargs
setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either fit, validate, test, or predict.

Source code in terratorch/datamodules/open_sentinel_map.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = OpenSentinelMap(
            split="train",
            data_root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
            pad_image = self.pad_image,
            truncate_image = self.truncate_image,
            **self.kwargs,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = OpenSentinelMap(
            split="val",
            data_root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
            pad_image = self.pad_image,
            truncate_image = self.truncate_image,
            **self.kwargs,
        )
    if stage in ["test"]:
        self.test_dataset = OpenSentinelMap(
            split="test",
            data_root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
            pad_image = self.pad_image,
            truncate_image = self.truncate_image,
            **self.kwargs,
        )
    if stage in ["predict"]:
        self.predict_dataset = OpenSentinelMap(
            split="test",
            data_root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            spatial_interpolate_and_stack_temporally = self.spatial_interpolate_and_stack_temporally,
            pad_image = self.pad_image,
            truncate_image = self.truncate_image,
            **self.kwargs,
        )

terratorch.datamodules.openearthmap

OpenEarthMapNonGeoDataModule

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for Open Earth Map.

Source code in terratorch/datamodules/openearthmap.py
class OpenEarthMapNonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Open Earth Map."""

    def __init__(
        self, 
        batch_size: int = 8, 
        num_workers: int = 0, 
        data_root: str = "./",
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential | None = None,
        **kwargs: Any
    ) -> None:
        """
        Initializes the OpenEarthMapNonGeoDataModule for the Open Earth Map dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for test data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            aug (AugmentationSequential | None, optional): Augmentation pipeline; if None, defaults to per-band normalization using the predefined MEANS and STDS of the selected bands.
            **kwargs: Additional keyword arguments. Can include 'bands' (list[str]) to specify the bands; defaults to OpenEarthMapNonGeo.all_band_names if not provided.
        """
        super().__init__(OpenEarthMapNonGeo, batch_size, num_workers, **kwargs)

        bands = kwargs.get("bands", OpenEarthMapNonGeo.all_band_names)
        self.means = torch.tensor([MEANS[b] for b in bands])
        self.stds = torch.tensor([STDS[b] for b in bands])
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.data_root = data_root
        self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(  
                split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val", data_root=self.data_root, transform=self.val_transform, **self.kwargs
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test", data_root=self.data_root, transform=self.predict_transform, **self.kwargs
            )
__init__(batch_size=8, num_workers=0, data_root='./', train_transform=None, val_transform=None, test_transform=None, predict_transform=None, aug=None, **kwargs)

Initializes the OpenEarthMapNonGeoDataModule for the Open Earth Map dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training data.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation data.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for test data.

  • predict_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for prediction data.

  • aug (AugmentationSequential, default: None ) –

    Augmentation pipeline; if None, defaults to per-band normalization using the predefined MEANS and STDS of the selected bands.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments. Can include 'bands' (list[str]) to specify the bands; defaults to OpenEarthMapNonGeo.all_band_names if not provided.
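When `aug` is None, the module builds `K.Normalize(self.means, self.stds)` from per-band statistics looked up for the selected `bands`. The NumPy sketch below shows the same per-channel `(x - mean) / std` operation on a `(C, H, W)` image. The band names and statistic values are hypothetical placeholders. They are not the real `MEANS`/`STDS` entries from `openearthmap.py`.

```python
import numpy as np

# Hypothetical per-band statistics (placeholders, not the library's values).
MEANS = {"RED": 100.0, "GREEN": 110.0, "BLUE": 120.0}
STDS = {"RED": 10.0, "GREEN": 20.0, "BLUE": 40.0}


def normalize(image: np.ndarray, bands: list[str]) -> np.ndarray:
    """Per-channel (x - mean) / std for a (C, H, W) image.

    Statistics are looked up per band, in band order, as the datamodule does
    before passing them to K.Normalize.
    """
    means = np.array([MEANS[b] for b in bands], dtype=np.float32)[:, None, None]
    stds = np.array([STDS[b] for b in bands], dtype=np.float32)[:, None, None]
    return (image - means) / stds


# One constant-valued 2x2 plane per band.
image = np.stack([np.full((2, 2), v, dtype=np.float32) for v in (110.0, 130.0, 200.0)])
out = normalize(image, ["RED", "GREEN", "BLUE"])
print(out[:, 0, 0])  # [1. 1. 2.]
```

Because the statistics are indexed by band name, the order of `bands` must match the channel order of the image.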

Source code in terratorch/datamodules/openearthmap.py
def __init__(
    self, 
    batch_size: int = 8, 
    num_workers: int = 0, 
    data_root: str = "./",
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    aug: AugmentationSequential | None = None,
    **kwargs: Any
) -> None:
    """
    Initializes the OpenEarthMapNonGeoDataModule for the Open Earth Map dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for test data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        aug (AugmentationSequential | None, optional): Augmentation pipeline; if None, defaults to per-band normalization using the predefined MEANS and STDS of the selected bands.
        **kwargs: Additional keyword arguments. Can include 'bands' (list[str]) to specify the bands; defaults to OpenEarthMapNonGeo.all_band_names if not provided.
    """
    super().__init__(OpenEarthMapNonGeo, batch_size, num_workers, **kwargs)

    bands = kwargs.get("bands", OpenEarthMapNonGeo.all_band_names)
    self.means = torch.tensor([MEANS[b] for b in bands])
    self.stds = torch.tensor([STDS[b] for b in bands])
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.data_root = data_root
    self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either fit, validate, test, or predict.

Source code in terratorch/datamodules/openearthmap.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(  
            split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="val", data_root=self.data_root, transform=self.val_transform, **self.kwargs
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="test", data_root=self.data_root, transform=self.predict_transform, **self.kwargs
        )

terratorch.datamodules.pastis

PASTISDataModule

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for PASTIS.

Source code in terratorch/datamodules/pastis.py
class PASTISDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for PASTIS."""

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        truncate_image: int | None = None,
        pad_image: int | None = None,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the PASTISDataModule for the PASTIS dataset.

        Args:
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Directory containing the dataset. Defaults to "./".
            truncate_image (int | None, optional): Truncate the time dimension of the image to
                a specified number of timesteps. If None, no truncation is performed.
            pad_image (int | None, optional): Pad the time dimension of the image to a specified
                number of timesteps. If None, no padding is applied.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(
            PASTIS,
            batch_size=batch_size,
            num_workers=num_workers,
            **kwargs,
        )
        self.truncate_image = truncate_image
        self.pad_image = pad_image
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.data_root = data_root
        self.kwargs = kwargs

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = PASTIS(
                folds=[1, 2, 3],
                data_root=self.data_root,
                transform=self.train_transform,
                truncate_image=self.truncate_image,
                pad_image=self.pad_image,
                **self.kwargs,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = PASTIS(
                folds=[4],
                data_root=self.data_root,
                transform=self.val_transform,
                truncate_image=self.truncate_image,
                pad_image=self.pad_image,
                **self.kwargs,
            )
        if stage in ["test"]:
            self.test_dataset = PASTIS(
                folds=[5],
                data_root=self.data_root,
                transform=self.test_transform,
                truncate_image=self.truncate_image,
                pad_image=self.pad_image,
                **self.kwargs,
            )
        if stage in ["predict"]:
            self.predict_dataset = PASTIS(
                folds=[5],
                data_root=self.data_root,
                transform=self.predict_transform,
                truncate_image=self.truncate_image,
                pad_image=self.pad_image,
                **self.kwargs,
            )
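Unlike the other modules on this page, `PASTISDataModule.setup` selects data by fold rather than by split name. The sketch below copies the fold assignment from the code above into a plain dictionary. `PASTIS_FOLDS` and `folds_for_stage` are hypothetical names used only for illustration. The final check confirms that training, validation and test folds do not overlap, and that prediction reads the test fold.

```python
# Fold assignment taken from PASTISDataModule.setup() above.
PASTIS_FOLDS = {"train": [1, 2, 3], "val": [4], "test": [5], "predict": [5]}


def folds_for_stage(stage: str) -> dict[str, list[int]]:
    """Return {split: folds} for the datasets that setup(stage) would build."""
    splits = {
        "fit": ["train", "val"],
        "validate": ["val"],
        "test": ["test"],
        "predict": ["predict"],
    }[stage]
    return {split: PASTIS_FOLDS[split] for split in splits}


print(folds_for_stage("fit"))      # {'train': [1, 2, 3], 'val': [4]}
print(folds_for_stage("predict"))  # {'predict': [5]}

# No fold is shared between training and evaluation.
train, val, test = (set(PASTIS_FOLDS[s]) for s in ("train", "val", "test"))
assert not (train & val) and not (train & test) and not (val & test)
```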
__init__(batch_size=8, num_workers=0, data_root='./', truncate_image=None, pad_image=None, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, **kwargs)

Initializes the PASTISDataModule for the PASTIS dataset.

Parameters:
  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Directory containing the dataset. Defaults to "./".

  • truncate_image (int | None, default: None ) –

    Truncate the time dimension of the image to a specified number of timesteps. If None, no truncation is performed.

  • pad_image (int | None, default: None ) –

    Pad the time dimension of the image to a specified number of timesteps. If None, no padding is applied.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training data.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation data.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for testing data.

  • predict_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for prediction data.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.
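Each `*_transform` parameter accepts an albumentations `Compose`, `None`, or a plain list of transforms. The constructor passes each one through `wrap_in_compose_is_list`. The sketch below assumes that helper wraps a list in a `Compose` and leaves other values unchanged. `Compose` here is a minimal stand-in for `albumentations.Compose`, not the real class. `wrap_if_list`, `add_one` and `double` are hypothetical. As in albumentations, transforms are called with keyword arguments and return a dict.

```python
from typing import Callable


class Compose:
    """Minimal stand-in for albumentations.Compose: apply transforms in order."""

    def __init__(self, transforms: list[Callable[..., dict]]):
        self.transforms = transforms

    def __call__(self, **sample) -> dict:
        for transform in self.transforms:
            sample = transform(**sample)
        return sample


def wrap_if_list(transform):
    """Assumed behaviour: lists become a Compose; None or a Compose pass through."""
    return Compose(transform) if isinstance(transform, list) else transform


def add_one(**sample) -> dict:
    return {**sample, "image": [v + 1 for v in sample["image"]]}


def double(**sample) -> dict:
    return {**sample, "image": [v * 2 for v in sample["image"]]}


pipeline = wrap_if_list([add_one, double])
print(pipeline(image=[1, 2, 3]))  # {'image': [4, 6, 8]}
```

Because the list order is preserved, `add_one` runs before `double`, just as transforms listed in a config file are applied top to bottom.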

Source code in terratorch/datamodules/pastis.py
def __init__(
    self,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    truncate_image: int | None = None,
    pad_image: int | None = None,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    **kwargs: Any,
) -> None:
    """
    Initializes the PASTISDataModule for the PASTIS dataset.

    Args:
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Directory containing the dataset. Defaults to "./".
        truncate_image (int | None, optional): Truncate the time dimension of the image to
            a specified number of timesteps. If None, no truncation is performed.
        pad_image (int | None, optional): Pad the time dimension of the image to a specified
            number of timesteps. If None, no padding is applied.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for testing data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        **kwargs: Additional keyword arguments.
    """
    super().__init__(
        PASTIS,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs,
    )
    self.truncate_image = truncate_image
    self.pad_image = pad_image
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.data_root = data_root
    self.kwargs = kwargs
setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either fit, validate, test, or predict.

Source code in terratorch/datamodules/pastis.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = PASTIS(
            folds=[1, 2, 3],
            data_root=self.data_root,
            transform=self.train_transform,
            truncate_image=self.truncate_image,
            pad_image=self.pad_image,
            **self.kwargs,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = PASTIS(
            folds=[4],
            data_root=self.data_root,
            transform=self.val_transform,
            truncate_image=self.truncate_image,
            pad_image=self.pad_image,
            **self.kwargs,
        )
    if stage in ["test"]:
        self.test_dataset = PASTIS(
            folds=[5],
            data_root=self.data_root,
            transform=self.test_transform,
            truncate_image=self.truncate_image,
            pad_image=self.pad_image,
            **self.kwargs,
        )
    if stage in ["predict"]:
        self.predict_dataset = PASTIS(
            folds=[5],
            data_root=self.data_root,
            transform=self.predict_transform,
            truncate_image=self.truncate_image,
            pad_image=self.pad_image,
            **self.kwargs,
        )

terratorch.datamodules.sen1floods11

Sen1Floods11NonGeoDataModule

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for Sen1Floods11.

Source code in terratorch/datamodules/sen1floods11.py
class Sen1Floods11NonGeoDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Sen1Floods11."""

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = Sen1Floods11NonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        drop_last: bool = True,
        constant_scale: float = 0.0001,
        no_data_replace: float | None = 0,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the Sen1Floods11NonGeoDataModule.

        Args:
            data_root (str): Root directory of the dataset.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            bands (Sequence[str], optional): List of bands to use. Defaults to Sen1Floods11NonGeo.all_band_names.
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for test data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
            constant_scale (float, optional): Factor by which the raw image values are multiplied. Defaults to 0.0001.
            no_data_replace (float | None, optional): Replacement value for missing data. Defaults to 0.
            no_label_replace (int | None, optional): Replacement value for missing labels. Defaults to -1.
            use_metadata (bool, optional): Whether to return metadata (time and location). Defaults to False.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(Sen1Floods11NonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root

        means = [MEANS[b] for b in bands]
        stds = [STDS[b] for b in bands]
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
        self.drop_last = drop_last
        self.constant_scale = constant_scale
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.use_metadata = use_metadata

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = self.dataset_class(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                constant_scale=self.constant_scale,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = self.dataset_class(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                constant_scale=self.constant_scale,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["test"]:
            self.test_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                constant_scale=self.constant_scale,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )
        if stage in ["predict"]:
            self.predict_dataset = self.dataset_class(
                split="test",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                constant_scale=self.constant_scale,
                no_data_replace=self.no_data_replace,
                no_label_replace=self.no_label_replace,
                use_metadata=self.use_metadata,
            )

    def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
        """Implement one or more PyTorch DataLoaders.

        Args:
            split: Either 'train', 'val', 'test', or 'predict'.

        Returns:
            A collection of data loaders specifying samples.

        Raises:
            MisconfigurationException: If :meth:`setup` does not define a
                dataset or sampler, or if the dataset or sampler has length 0.
        """
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
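`_dataloader_factory` looks up the dataset and batch size for the requested split before building the loader. The stdlib-only sketch below mimics that lookup, assuming the inherited `_valid_attribute` (from TorchGeo's data module base class) returns the first of its arguments that is set on the data module and raises otherwise. `valid_attribute` and `loader_options` are hypothetical names, and a plain `ValueError` stands in for `MisconfigurationException`.

```python
from types import SimpleNamespace
from typing import Any


def valid_attribute(obj: Any, *names: str) -> Any:
    """Return the first attribute in `names` that is set on `obj` (hypothetical helper)."""
    for name in names:
        value = getattr(obj, name, None)
        if value is None:
            continue  # not defined for this split, try the next (more generic) name
        if not value:
            raise ValueError(f"{name} has length 0")  # e.g. empty dataset or batch size 0
        return value
    raise ValueError(f"setup does not define a {names[0]}")


def loader_options(dm: Any, split: str) -> dict[str, Any]:
    """Collect the DataLoader options that _dataloader_factory derives for `split`."""
    return {
        "batch_size": valid_attribute(dm, f"{split}_batch_size", "batch_size"),
        "shuffle": split == "train",  # only the training split is shuffled
        "drop_last": split == "train" and dm.drop_last,  # drop_last only affects training
    }


# A per-split batch size (val_batch_size) takes precedence over the shared one.
dm = SimpleNamespace(batch_size=4, val_batch_size=16, drop_last=True)
print(loader_options(dm, "train"))  # {'batch_size': 4, 'shuffle': True, 'drop_last': True}
print(loader_options(dm, "val"))    # {'batch_size': 16, 'shuffle': False, 'drop_last': False}
```

Under this assumption, a split-specific attribute such as `val_batch_size`, when set, overrides the shared `batch_size`, and only the training loader shuffles or drops its last incomplete batch.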
__init__(data_root, batch_size=4, num_workers=0, bands=Sen1Floods11NonGeo.all_band_names, train_transform=None, val_transform=None, test_transform=None, predict_transform=None, drop_last=True, constant_scale=0.0001, no_data_replace=0, no_label_replace=-1, use_metadata=False, **kwargs)

Initializes the Sen1Floods11NonGeoDataModule.

Parameters:
  • data_root (str) –

    Root directory of the dataset.

  • batch_size (int, default: 4 ) –

    Batch size for DataLoaders. Defaults to 4.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • bands (Sequence[str], default: all_band_names ) –

    List of bands to use. Defaults to Sen1Floods11NonGeo.all_band_names.

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training data.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation data.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for test data.

  • predict_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for prediction data.

  • drop_last (bool, default: True ) –

    Whether to drop the last incomplete batch. Defaults to True.

  • constant_scale (float, default: 0.0001 ) –

    Factor by which image values are multiplied. Defaults to 0.0001.

  • no_data_replace (float | None, default: 0 ) –

    Replacement value for missing data. Defaults to 0.

  • no_label_replace (int | None, default: -1 ) –

    Replacement value for missing labels. Defaults to -1.

  • use_metadata (bool, default: False ) –

    Whether to return metadata info (time and location). Defaults to False.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.

Source code in terratorch/datamodules/sen1floods11.py
def __init__(
    self,
    data_root: str,
    batch_size: int = 4,
    num_workers: int = 0,
    bands: Sequence[str] = Sen1Floods11NonGeo.all_band_names,
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    drop_last: bool = True,
    constant_scale: float = 0.0001,
    no_data_replace: float | None = 0,
    no_label_replace: int | None = -1,
    use_metadata: bool = False,
    **kwargs: Any,
) -> None:
    """
    Initializes the Sen1Floods11NonGeoDataModule.

    Args:
        data_root (str): Root directory of the dataset.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 4.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        bands (Sequence[str], optional): List of bands to use. Defaults to Sen1Floods11NonGeo.all_band_names.
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for test data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
        constant_scale (float, optional): Factor by which image values are multiplied. Defaults to 0.0001.
        no_data_replace (float | None, optional): Replacement value for missing data. Defaults to 0.
        no_label_replace (int | None, optional): Replacement value for missing labels. Defaults to -1.
        use_metadata (bool, optional): Whether to return metadata info (time and location). Defaults to False.
        **kwargs: Additional keyword arguments.
    """
    super().__init__(Sen1Floods11NonGeo, batch_size, num_workers, **kwargs)
    self.data_root = data_root

    means = [MEANS[b] for b in bands]
    stds = [STDS[b] for b in bands]
    self.bands = bands
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.aug = AugmentationSequential(K.Normalize(means, stds), data_keys=["image"])
    self.drop_last = drop_last
    self.constant_scale = constant_scale
    self.no_data_replace = no_data_replace
    self.no_label_replace = no_label_replace
    self.use_metadata = use_metadata
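The constructor selects normalization statistics by band name, so `means` and `stds` follow the order of `bands`; they are presumably applied to each batch through `self.aug` after the dataset has multiplied raw values by `constant_scale`. The stdlib-only sketch below shows that arithmetic for a single pixel. The statistics and band names are made up for illustration (they are not the real `MEANS`/`STDS` values), and `normalize_pixel` is a hypothetical helper.

```python
# Hypothetical per-band statistics keyed by band name (NOT the real Sen1Floods11 values).
MEANS = {"BLUE": 0.16, "GREEN": 0.14, "RED": 0.13, "NIR_NARROW": 0.28}
STDS = {"BLUE": 0.07, "GREEN": 0.07, "RED": 0.09, "NIR_NARROW": 0.09}


def normalize_pixel(raw: list[float], bands: list[str], constant_scale: float = 0.0001) -> list[float]:
    """Scale raw digital numbers, then standardize each band with its own mean/std."""
    means = [MEANS[b] for b in bands]  # same selection as in __init__: order follows `bands`
    stds = [STDS[b] for b in bands]
    scaled = [v * constant_scale for v in raw]  # assumption: scaling happens in the dataset
    return [(v - m) / s for v, m, s in zip(scaled, means, stds)]


# Requesting bands in a different order reorders the statistics accordingly.
print(normalize_pixel([2300.0, 3700.0], ["BLUE", "NIR_NARROW"]))  # ≈ [1.0, 1.0] up to rounding
```

A band name that is missing from the statistics tables raises a `KeyError` at construction time, which is why `bands` must use the names the module defines.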
setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either fit, validate, test, or predict.

Source code in terratorch/datamodules/sen1floods11.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = self.dataset_class(
            split="train",
            data_root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            constant_scale=self.constant_scale,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = self.dataset_class(
            split="val",
            data_root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            constant_scale=self.constant_scale,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["test"]:
        self.test_dataset = self.dataset_class(
            split="test",
            data_root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            constant_scale=self.constant_scale,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
    if stage in ["predict"]:
        self.predict_dataset = self.dataset_class(
            split="test",
            data_root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            constant_scale=self.constant_scale,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )
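`setup` builds only the datasets a given stage needs, and the predict dataset reuses the `test` split. The stdlib-only sketch below restates that mapping as a table derived from the `if` branches above; `STAGE_SPLITS` and `datasets_for_stage` are hypothetical. Unlike the real method, which silently does nothing for an unrecognised stage, the sketch raises.

```python
# (attribute, split) pairs created by setup(), derived from the branches above.
STAGE_SPLITS = {
    "fit": [("train_dataset", "train"), ("val_dataset", "val")],
    "validate": [("val_dataset", "val")],
    "test": [("test_dataset", "test")],
    "predict": [("predict_dataset", "test")],  # prediction reads the test split
}


def datasets_for_stage(stage: str) -> list[tuple[str, str]]:
    """Return which dataset attributes setup(stage) assigns, and from which split."""
    if stage not in STAGE_SPLITS:
        raise ValueError(f"unknown stage {stage!r}; expected one of {sorted(STAGE_SPLITS)}")
    return STAGE_SPLITS[stage]


for stage in ("fit", "validate", "test", "predict"):
    print(stage, datasets_for_stage(stage))
```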

terratorch.datamodules.sen4agrinet

Sen4AgriNetDataModule

Bases: NonGeoDataModule

NonGeo LightningDataModule implementation for Sen4AgriNet.

Source code in terratorch/datamodules/sen4agrinet.py
class Sen4AgriNetDataModule(NonGeoDataModule):
    """NonGeo LightningDataModule implementation for Sen4AgriNet."""

    def __init__(
        self,
        bands: list[str] | None = None,
        batch_size: int = 8,
        num_workers: int = 0,
        data_root: str = "./",
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        predict_transform: A.Compose | None | list[A.BasicTransform] = None,
        seed: int = 42,
        scenario: str = "random",
        requires_norm: bool = True,
        binary_labels: bool = False,
        linear_encoder: dict | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the Sen4AgriNetDataModule for the Sen4AgriNet dataset.

        Args:
            bands (list[str] | None, optional): List of bands to use. Defaults to None.
            batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.
            data_root (str, optional): Root directory of the dataset. Defaults to "./".
            train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
            val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
            test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for test data.
            predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
            seed (int, optional): Random seed for reproducibility. Defaults to 42.
            scenario (str, optional): Defines the splitting scenario to use. Defaults to 'random'. Options are:
                - 'random': Random split of the data.
                - 'spatial': Split by geographical regions (Catalonia and France).
                - 'spatio-temporal': Split by region and year (France 2019 and Catalonia 2020).
            requires_norm (bool, optional): Whether normalization is required. Defaults to True.
            binary_labels (bool, optional): Whether to use binary labels. Defaults to False.
            linear_encoder (dict, optional): Mapping for label encoding. Defaults to None.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(
            Sen4AgriNet,
            batch_size=batch_size,
            num_workers=num_workers,
            **kwargs,
        )
        self.bands = bands
        self.seed = seed
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        self.predict_transform = wrap_in_compose_is_list(predict_transform)
        self.data_root = data_root
        self.scenario = scenario
        self.requires_norm = requires_norm
        self.binary_labels = binary_labels
        self.linear_encoder = linear_encoder
        self.kwargs = kwargs

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either fit, validate, test, or predict.
        """
        if stage in ["fit"]:
            self.train_dataset = Sen4AgriNet(
                split="train",
                data_root=self.data_root,
                transform=self.train_transform,
                bands=self.bands,
                seed=self.seed,
                scenario=self.scenario,
                requires_norm=self.requires_norm,
                binary_labels=self.binary_labels,
                linear_encoder=self.linear_encoder,
                **self.kwargs,
            )
        if stage in ["fit", "validate"]:
            self.val_dataset = Sen4AgriNet(
                split="val",
                data_root=self.data_root,
                transform=self.val_transform,
                bands=self.bands,
                seed=self.seed,
                scenario=self.scenario,
                requires_norm=self.requires_norm,
                binary_labels=self.binary_labels,
                linear_encoder=self.linear_encoder,
                **self.kwargs,
            )
        if stage in ["test"]:
            self.test_dataset = Sen4AgriNet(
                split="test",
                data_root=self.data_root,
                transform=self.test_transform,
                bands=self.bands,
                seed=self.seed,
                scenario=self.scenario,
                requires_norm=self.requires_norm,
                binary_labels=self.binary_labels,
                linear_encoder=self.linear_encoder,
                **self.kwargs,
            )
        if stage in ["predict"]:
            self.predict_dataset = Sen4AgriNet(
                split="test",
                data_root=self.data_root,
                transform=self.predict_transform,
                bands=self.bands,
                seed=self.seed,
                scenario=self.scenario,
                requires_norm=self.requires_norm,
                binary_labels=self.binary_labels,
                linear_encoder=self.linear_encoder,
                **self.kwargs,
            )
__init__(bands=None, batch_size=8, num_workers=0, data_root='./', train_transform=None, val_transform=None, test_transform=None, predict_transform=None, seed=42, scenario='random', requires_norm=True, binary_labels=False, linear_encoder=None, **kwargs)

Initializes the Sen4AgriNetDataModule for the Sen4AgriNet dataset.

Parameters:
  • bands (list[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • batch_size (int, default: 8 ) –

    Batch size for DataLoaders. Defaults to 8.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. Defaults to 0.

  • data_root (str, default: './' ) –

    Root directory of the dataset. Defaults to "./".

  • train_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for training data.

  • val_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for validation data.

  • test_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for test data.

  • predict_transform (Compose | None | list[BasicTransform], default: None ) –

    Transformations for prediction data.

  • seed (int, default: 42 ) –

    Random seed for reproducibility. Defaults to 42.

  • scenario (str, default: 'random' ) –

    Defines the splitting scenario to use. Options are:
      - 'random': Random split of the data.
      - 'spatial': Split by geographical regions (Catalonia and France).
      - 'spatio-temporal': Split by region and year (France 2019 and Catalonia 2020).

  • requires_norm (bool, default: True ) –

    Whether normalization is required. Defaults to True.

  • binary_labels (bool, default: False ) –

    Whether to use binary labels. Defaults to False.

  • linear_encoder (dict, default: None ) –

    Mapping for label encoding. Defaults to None.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments.
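`linear_encoder` is documented only as a mapping for label encoding. Assuming it maps raw crop-type codes to contiguous class indices (a common way to turn a sparse set of codes into targets for a fixed-size classification head), applying it to a label patch could look like the sketch below. The codes, the mapping, the background value, and `encode_labels` are all hypothetical.

```python
def encode_labels(patch: list[list[int]], linear_encoder: dict[int, int], fill: int = 0) -> list[list[int]]:
    """Map every raw label code in a 2-D patch to its class index.

    Codes missing from the mapping are sent to `fill` (assumed background class).
    """
    return [[linear_encoder.get(code, fill) for code in row] for row in patch]


# Hypothetical raw crop codes; sorting makes the index assignment deterministic.
raw_codes = [110, 140, 330]
linear_encoder = {code: idx + 1 for idx, code in enumerate(sorted(raw_codes))}  # 0 reserved for background

patch = [[110, 330], [999, 140]]
print(encode_labels(patch, linear_encoder))  # [[1, 3], [0, 2]]
```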

Source code in terratorch/datamodules/sen4agrinet.py
def __init__(
    self,
    bands: list[str] | None = None,
    batch_size: int = 8,
    num_workers: int = 0,
    data_root: str = "./",
    train_transform: A.Compose | None | list[A.BasicTransform] = None,
    val_transform: A.Compose | None | list[A.BasicTransform] = None,
    test_transform: A.Compose | None | list[A.BasicTransform] = None,
    predict_transform: A.Compose | None | list[A.BasicTransform] = None,
    seed: int = 42,
    scenario: str = "random",
    requires_norm: bool = True,
    binary_labels: bool = False,
    linear_encoder: dict | None = None,
    **kwargs: Any,
) -> None:
    """
    Initializes the Sen4AgriNetDataModule for the Sen4AgriNet dataset.

    Args:
        bands (list[str] | None, optional): List of bands to use. Defaults to None.
        batch_size (int, optional): Batch size for DataLoaders. Defaults to 8.
        num_workers (int, optional): Number of workers for data loading. Defaults to 0.
        data_root (str, optional): Root directory of the dataset. Defaults to "./".
        train_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for training data.
        val_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for validation data.
        test_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for test data.
        predict_transform (A.Compose | None | list[A.BasicTransform], optional): Transformations for prediction data.
        seed (int, optional): Random seed for reproducibility. Defaults to 42.
        scenario (str, optional): Defines the splitting scenario to use. Defaults to 'random'. Options are:
            - 'random': Random split of the data.
            - 'spatial': Split by geographical regions (Catalonia and France).
            - 'spatio-temporal': Split by region and year (France 2019 and Catalonia 2020).
        requires_norm (bool, optional): Whether normalization is required. Defaults to True.
        binary_labels (bool, optional): Whether to use binary labels. Defaults to False.
        linear_encoder (dict, optional): Mapping for label encoding. Defaults to None.
        **kwargs: Additional keyword arguments.
    """
    super().__init__(
        Sen4AgriNet,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs,
    )
    self.bands = bands
    self.seed = seed
    self.train_transform = wrap_in_compose_is_list(train_transform)
    self.val_transform = wrap_in_compose_is_list(val_transform)
    self.test_transform = wrap_in_compose_is_list(test_transform)
    self.predict_transform = wrap_in_compose_is_list(predict_transform)
    self.data_root = data_root
    self.scenario = scenario
    self.requires_norm = requires_norm
    self.binary_labels = binary_labels
    self.linear_encoder = linear_encoder
    self.kwargs = kwargs
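Each transform argument passes through `wrap_in_compose_is_list`. Judging by its name and the accepted types `A.Compose | None | list[A.BasicTransform]`, it turns a plain list of transforms into a single composed pipeline and leaves `None` or an existing `Compose` untouched. The sketch below illustrates that behaviour without Albumentations: `Compose` is a minimal stand-in class and `wrap_if_list` is a hypothetical helper, not the library function.

```python
from typing import Any, Callable


class Compose:
    """Stand-in for albumentations.Compose: applies transforms in order."""

    def __init__(self, transforms: list[Callable[..., dict]]):
        self.transforms = list(transforms)

    def __call__(self, **data: Any) -> dict:
        for t in self.transforms:
            data = t(**data)  # each transform returns the updated sample dict
        return data


def wrap_if_list(transform: Any) -> Any:
    """Wrap a list of transforms in Compose; return None or a Compose unchanged."""
    if isinstance(transform, list):
        return Compose(transform)
    return transform


def add_one(**d):
    return {**d, "image": d["image"] + 1}


def double(**d):
    return {**d, "image": d["image"] * 2}


pipeline = wrap_if_list([add_one, double])
print(pipeline(image=3))  # {'image': 8}
```

Accepting a list keeps config files short, since each transform can be declared as a list entry without spelling out the enclosing `Compose`.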
setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either fit, validate, test, or predict.

Source code in terratorch/datamodules/sen4agrinet.py
def setup(self, stage: str) -> None:
    """Set up datasets.

    Args:
        stage: Either fit, validate, test, or predict.
    """
    if stage in ["fit"]:
        self.train_dataset = Sen4AgriNet(
            split="train",
            data_root=self.data_root,
            transform=self.train_transform,
            bands=self.bands,
            seed=self.seed,
            scenario=self.scenario,
            requires_norm=self.requires_norm,
            binary_labels=self.binary_labels,
            linear_encoder=self.linear_encoder,
            **self.kwargs,
        )
    if stage in ["fit", "validate"]:
        self.val_dataset = Sen4AgriNet(
            split="val",
            data_root=self.data_root,
            transform=self.val_transform,
            bands=self.bands,
            seed=self.seed,
            scenario=self.scenario,
            requires_norm=self.requires_norm,
            binary_labels=self.binary_labels,
            linear_encoder=self.linear_encoder,
            **self.kwargs,
        )
    if stage in ["test"]:
        self.test_dataset = Sen4AgriNet(
            split="test",
            data_root=self.data_root,
            transform=self.test_transform,
            bands=self.bands,
            seed=self.seed,
            scenario=self.scenario,
            requires_norm=self.requires_norm,
            binary_labels=self.binary_labels,
            linear_encoder=self.linear_encoder,
            **self.kwargs,
        )
    if stage in ["predict"]:
        self.predict_dataset = Sen4AgriNet(
            split="test",
            data_root=self.data_root,
            transform=self.predict_transform,
            bands=self.bands,
            seed=self.seed,
            scenario=self.scenario,
            requires_norm=self.requires_norm,
            binary_labels=self.binary_labels,
            linear_encoder=self.linear_encoder,
            **self.kwargs,
        )

terratorch.datamodules.sen4map

Sen4MapLucasDataModule

Bases: LightningDataModule

NonGeo LightningDataModule implementation for Sen4map.

Source code in terratorch/datamodules/sen4map.py
class Sen4MapLucasDataModule(pl.LightningDataModule):
    """NonGeo LightningDataModule implementation for Sen4map."""

    def __init__(
            self, 
            batch_size,
            num_workers,
            prefetch_factor = 0,
            # dataset_bands:list[HLSBands|int] = None,
            # input_bands:list[HLSBands|int] = None,
            train_hdf5_path = None,
            train_hdf5_keys_path = None,
            test_hdf5_path = None,
            test_hdf5_keys_path = None,
            val_hdf5_path = None,
            val_hdf5_keys_path = None,
            **kwargs
            ):
        """
        Initializes the Sen4MapLucasDataModule for handling Sen4Map monthly composites.

        Args:
            batch_size (int): Batch size for DataLoaders.
            num_workers (int): Number of worker processes for data loading.
            prefetch_factor (int, optional): Number of batches loaded in advance by each worker. Defaults to 0.
            train_hdf5_path (str, optional): Path to the training HDF5 file.
            train_hdf5_keys_path (str, optional): Path to the training HDF5 keys file.
            test_hdf5_path (str, optional): Path to the testing HDF5 file.
            test_hdf5_keys_path (str, optional): Path to the testing HDF5 keys file.
            val_hdf5_path (str, optional): Path to the validation HDF5 file.
            val_hdf5_keys_path (str, optional): Path to the validation HDF5 keys file.
            train_hdf5_keys_save_path (str, optional): (from kwargs) Path to save generated train keys.
            test_hdf5_keys_save_path (str, optional): (from kwargs) Path to save generated test keys.
            val_hdf5_keys_save_path (str, optional): (from kwargs) Path to save generated validation keys.
            shuffle (bool, optional): Default shuffle flag for the training data, used when train_shuffle is unset.
            train_shuffle (bool, optional): Shuffle flag for training data; defaults to shuffle if unset.
            val_shuffle (bool, optional): Shuffle flag for validation data.
            test_shuffle (bool, optional): Shuffle flag for test data.
            train_data_fraction (float, optional): Fraction of training data to use (the leading portion of the stored keys). Values other than 1.0 require train_hdf5_keys_path. Defaults to 1.0.
            val_data_fraction (float, optional): Fraction of validation data to use. Values other than 1.0 require val_hdf5_keys_path. Defaults to 1.0.
            test_data_fraction (float, optional): Fraction of test data to use. Values other than 1.0 require test_hdf5_keys_path. Defaults to 1.0.
            all_hdf5_data_path (str, optional): Single HDF5 file used for all splits. Cannot be combined with train_hdf5_path, val_hdf5_path, or test_hdf5_path; validation and test keys are subtracted from the training keys.
            resize (bool, optional): Whether to resize images. Defaults to False.
            resize_to (int or tuple, optional): Target size for resizing images. Required if resize is True.
            resize_interpolation (str, optional): Interpolation mode for resizing: 'bilinear', 'bicubic', 'nearest', or 'nearest_exact'. Defaults to 'bilinear' when resize is True.
            resize_antialiasing (bool, optional): Whether to apply antialiasing during resizing. Defaults to True.
            **kwargs: Additional keyword arguments.
        """
        self.prepare_data_per_node = False
        self._log_hyperparams = None
        self.allow_zero_length_dataloader_with_multiple_devices = False

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor

        self.train_hdf5_path = train_hdf5_path
        self.test_hdf5_path = test_hdf5_path
        self.val_hdf5_path = val_hdf5_path

        self.train_hdf5_keys_path = train_hdf5_keys_path
        self.test_hdf5_keys_path = test_hdf5_keys_path
        self.val_hdf5_keys_path = val_hdf5_keys_path

        if train_hdf5_path and not train_hdf5_keys_path: print(f"Train dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")
        if test_hdf5_path and not test_hdf5_keys_path: print(f"Test dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")
        if val_hdf5_path and not val_hdf5_keys_path: print(f"Val dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")

        self.train_hdf5_keys_save_path = kwargs.pop("train_hdf5_keys_save_path", None)
        self.test_hdf5_keys_save_path = kwargs.pop("test_hdf5_keys_save_path", None)
        self.val_hdf5_keys_save_path = kwargs.pop("val_hdf5_keys_save_path", None)

        self.shuffle = kwargs.pop("shuffle", None)
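        train_shuffle = kwargs.pop("train_shuffle", None)
        # Fall back to the global flag only when train_shuffle is unset, so an explicit False is kept.
        self.train_shuffle = self.shuffle if train_shuffle is None else train_shuffle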
        self.val_shuffle = kwargs.pop("val_shuffle", None)
        self.test_shuffle = kwargs.pop("test_shuffle", None)

        self.train_data_fraction = kwargs.pop("train_data_fraction", 1.0)
        self.val_data_fraction = kwargs.pop("val_data_fraction", 1.0)
        self.test_data_fraction = kwargs.pop("test_data_fraction", 1.0)

        if self.train_data_fraction != 1.0 and not train_hdf5_keys_path:
            raise ValueError("train_data_fraction provided as non-unity but train_hdf5_keys_path is unset.")
        if self.val_data_fraction != 1.0 and not val_hdf5_keys_path:
            raise ValueError("val_data_fraction provided as non-unity but val_hdf5_keys_path is unset.")
        if self.test_data_fraction != 1.0 and not test_hdf5_keys_path:
            raise ValueError("test_data_fraction provided as non-unity but test_hdf5_keys_path is unset.")

        all_hdf5_data_path = kwargs.pop("all_hdf5_data_path", None)
        if all_hdf5_data_path is not None:
            print("all_hdf5_data_path provided; it will be used as the data path for all splits.\nThe keys in train_hdf5_keys_path are assumed to cover the entire dataset; validation and test keys will be subtracted from them.")
            if self.train_hdf5_path: raise ValueError("Both all_hdf5_data_path and train_hdf5_path were provided; remove train_hdf5_path.")
            if self.val_hdf5_path: raise ValueError("Both all_hdf5_data_path and val_hdf5_path were provided; remove val_hdf5_path.")
            if self.test_hdf5_path: raise ValueError("Both all_hdf5_data_path and test_hdf5_path were provided; remove test_hdf5_path.")
            self.train_hdf5_path = all_hdf5_data_path
            self.val_hdf5_path = all_hdf5_data_path
            self.test_hdf5_path = all_hdf5_data_path
            self.reduce_train_keys = True
        else:
            self.reduce_train_keys = False

        self.resize = kwargs.pop("resize", False)
        self.resize_to = kwargs.pop("resize_to", None)
        if self.resize and self.resize_to is None:
            raise ValueError(f"Config provided resize as True, but resize_to parameter not given")
        self.resize_interpolation = kwargs.pop("resize_interpolation", None)
        if self.resize and self.resize_interpolation is None:
            print(f"Config provided resize as True, but resize_interpolation mode not given. Will assume default bilinear")
            self.resize_interpolation = "bilinear"
        interpolation_dict = {
            "bilinear": InterpolationMode.BILINEAR,
            "bicubic": InterpolationMode.BICUBIC,
            "nearest": InterpolationMode.NEAREST,
            "nearest_exact": InterpolationMode.NEAREST_EXACT
        }
        if self.resize:
            if self.resize_interpolation not in interpolation_dict.keys():
                raise ValueError(f"resize_interpolation provided as {self.resize_interpolation}, but valid options are: {list(interpolation_dict)}")
            self.resize_interpolation = interpolation_dict[self.resize_interpolation]
        self.resize_antialiasing = kwargs.pop("resize_antialiasing", True)

        self.kwargs = kwargs

    def _load_hdf5_keys_from_path(self, path, fraction=1.0):
        if path is None: return None
        with open(path, "rb") as f:
            keys = pickle.load(f)
            return keys[:int(fraction*len(keys))]

    def setup(self, stage: str):
        """Set up datasets.

        Args:
            stage: Either fit or test; other stages are not set up by this data module.
        """
        if stage == "fit":
            train_keys = self._load_hdf5_keys_from_path(self.train_hdf5_keys_path, fraction=self.train_data_fraction)
            val_keys = self._load_hdf5_keys_from_path(self.val_hdf5_keys_path, fraction=self.val_data_fraction)
            if self.reduce_train_keys:
                test_keys = self._load_hdf5_keys_from_path(self.test_hdf5_keys_path, fraction=self.test_data_fraction)
                train_keys = list(set(train_keys) - set(val_keys) - set(test_keys))
            train_file = h5py.File(self.train_hdf5_path, 'r')
            self.lucasS2_train = Sen4MapDatasetMonthlyComposites(
                train_file, 
                h5data_keys = train_keys, 
                resize = self.resize,
                resize_to = self.resize_to,
                resize_interpolation = self.resize_interpolation,
                resize_antialiasing = self.resize_antialiasing,
                save_keys_path = self.train_hdf5_keys_save_path,
                **self.kwargs
            )
            val_file = h5py.File(self.val_hdf5_path, 'r')
            self.lucasS2_val = Sen4MapDatasetMonthlyComposites(
                val_file, 
                h5data_keys=val_keys, 
                resize = self.resize,
                resize_to = self.resize_to,
                resize_interpolation = self.resize_interpolation,
                resize_antialiasing = self.resize_antialiasing,
                save_keys_path = self.val_hdf5_keys_save_path,
                **self.kwargs
            )
        if stage == "test":
            test_file = h5py.File(self.test_hdf5_path, 'r')
            test_keys = self._load_hdf5_keys_from_path(self.test_hdf5_keys_path, fraction=self.test_data_fraction)
            self.lucasS2_test = Sen4MapDatasetMonthlyComposites(
                test_file, 
                h5data_keys=test_keys, 
                resize = self.resize,
                resize_to = self.resize_to,
                resize_interpolation = self.resize_interpolation,
                resize_antialiasing = self.resize_antialiasing,
                save_keys_path = self.test_hdf5_keys_save_path,
                **self.kwargs
            )

    def train_dataloader(self):
        return DataLoader(self.lucasS2_train, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=self.prefetch_factor, shuffle=self.train_shuffle)

    def val_dataloader(self):
        return DataLoader(self.lucasS2_val, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=self.prefetch_factor, shuffle=self.val_shuffle)

    def test_dataloader(self):
        return DataLoader(self.lucasS2_test, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=self.prefetch_factor, shuffle=self.test_shuffle)
__init__(batch_size, num_workers, prefetch_factor=0, train_hdf5_path=None, train_hdf5_keys_path=None, test_hdf5_path=None, test_hdf5_keys_path=None, val_hdf5_path=None, val_hdf5_keys_path=None, **kwargs)

Initializes the Sen4MapLucasDataModule for handling Sen4Map monthly composites.

Parameters:
  • batch_size (int) –

    Batch size for DataLoaders.

  • num_workers (int) –

    Number of worker processes for data loading.

  • prefetch_factor (int, default: 0 ) –

    Number of batches loaded in advance by each worker. Defaults to 0.

  • train_hdf5_path (str, default: None ) –

    Path to the training HDF5 file.

  • train_hdf5_keys_path (str, default: None ) –

    Path to the training HDF5 keys file.

  • test_hdf5_path (str, default: None ) –

    Path to the testing HDF5 file.

  • test_hdf5_keys_path (str, default: None ) –

    Path to the testing HDF5 keys file.

  • val_hdf5_path (str, default: None ) –

    Path to the validation HDF5 file.

  • val_hdf5_keys_path (str, default: None ) –

    Path to the validation HDF5 keys file.

  • train_hdf5_keys_save_path (str) –

    (from kwargs) Path to save generated train keys.

  • test_hdf5_keys_save_path (str) –

    (from kwargs) Path to save generated test keys.

  • val_hdf5_keys_save_path (str) –

    (from kwargs) Path to save generated validation keys.

  • shuffle (bool) –

    Default shuffle flag for the training data, used when train_shuffle is unset.

  • train_shuffle (bool) –

    Shuffle flag for training data; defaults to global shuffle if unset.

  • val_shuffle (bool) –

    Shuffle flag for validation data.

  • test_shuffle (bool) –

    Shuffle flag for test data.

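  • train_data_fraction (float) –

    Fraction of training data to use. Defaults to 1.0. Any other value requires train_hdf5_keys_path; otherwise a ValueError is raised.

  • val_data_fraction (float) –

    Fraction of validation data to use. Defaults to 1.0. Any other value requires val_hdf5_keys_path; otherwise a ValueError is raised.

  • test_data_fraction (float) –

    Fraction of test data to use. Defaults to 1.0. Any other value requires test_hdf5_keys_path; otherwise a ValueError is raised.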

  • all_hdf5_data_path (str) –

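    General HDF5 data path shared by all splits. It cannot be combined with train_hdf5_path, val_hdf5_path, or test_hdf5_path (a ValueError is raised). When set, the keys in train_hdf5_keys_path are assumed to cover the entire dataset, and the validation and test keys are subtracted from them.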

  • resize (bool) –

    Whether to resize images. Defaults to False.

  • resize_to (int or tuple) –

    Target size for resizing images. Required when resize is True.

  • resize_interpolation (str) –

    Interpolation mode for resizing: 'bilinear', 'bicubic', 'nearest', or 'nearest_exact'. Defaults to 'bilinear' when resize is True.

  • resize_antialiasing (bool) –

    Whether to apply antialiasing during resizing. Defaults to True.

  • **kwargs

    Additional keyword arguments.

Source code in terratorch/datamodules/sen4map.py
def __init__(
        self, 
        batch_size,
        num_workers,
        prefetch_factor = 0,
        # dataset_bands:list[HLSBands|int] = None,
        # input_bands:list[HLSBands|int] = None,
        train_hdf5_path = None,
        train_hdf5_keys_path = None,
        test_hdf5_path = None,
        test_hdf5_keys_path = None,
        val_hdf5_path = None,
        val_hdf5_keys_path = None,
        **kwargs
        ):
    """
    Initializes the Sen4MapLucasDataModule for handling Sen4Map monthly composites.

    Args:
        batch_size (int): Batch size for DataLoaders.
        num_workers (int): Number of worker processes for data loading.
        prefetch_factor (int, optional): Number of samples to prefetch per worker. Defaults to 0.
        train_hdf5_path (str, optional): Path to the training HDF5 file.
        train_hdf5_keys_path (str, optional): Path to the training HDF5 keys file.
        test_hdf5_path (str, optional): Path to the testing HDF5 file.
        test_hdf5_keys_path (str, optional): Path to the testing HDF5 keys file.
        val_hdf5_path (str, optional): Path to the validation HDF5 file.
        val_hdf5_keys_path (str, optional): Path to the validation HDF5 keys file.
        train_hdf5_keys_save_path (str, optional): (from kwargs) Path to save generated train keys.
        test_hdf5_keys_save_path (str, optional): (from kwargs) Path to save generated test keys.
        val_hdf5_keys_save_path (str, optional): (from kwargs) Path to save generated validation keys.
        shuffle (bool, optional): Global shuffle flag.
        train_shuffle (bool, optional): Shuffle flag for training data; defaults to global shuffle if unset.
        val_shuffle (bool, optional): Shuffle flag for validation data.
        test_shuffle (bool, optional): Shuffle flag for test data.
        train_data_fraction (float, optional): Fraction of training data to use. Defaults to 1.0.
        val_data_fraction (float, optional): Fraction of validation data to use. Defaults to 1.0.
        test_data_fraction (float, optional): Fraction of test data to use. Defaults to 1.0.
        all_hdf5_data_path (str, optional): General HDF5 data path for all splits. If provided, overrides specific paths.
        resize (bool, optional): Whether to resize images. Defaults to False.
        resize_to (int or tuple, optional): Target size for resizing images.
        resize_interpolation (str, optional): Interpolation mode for resizing ('bilinear', 'bicubic', etc.).
        resize_antialiasing (bool, optional): Whether to apply antialiasing during resizing. Defaults to True.
        **kwargs: Additional keyword arguments.
    """
    self.prepare_data_per_node = False
    self._log_hyperparams = None
    self.allow_zero_length_dataloader_with_multiple_devices = False

    self.batch_size = batch_size
    self.num_workers = num_workers
    self.prefetch_factor = prefetch_factor

    self.train_hdf5_path = train_hdf5_path
    self.test_hdf5_path = test_hdf5_path
    self.val_hdf5_path = val_hdf5_path

    self.train_hdf5_keys_path = train_hdf5_keys_path
    self.test_hdf5_keys_path = test_hdf5_keys_path
    self.val_hdf5_keys_path = val_hdf5_keys_path

    if train_hdf5_path and not train_hdf5_keys_path: print(f"Train dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")
    if test_hdf5_path and not test_hdf5_keys_path: print(f"Test dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")
    if val_hdf5_path and not val_hdf5_keys_path: print(f"Val dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")

    self.train_hdf5_keys_save_path = kwargs.pop("train_hdf5_keys_save_path", None)
    self.test_hdf5_keys_save_path = kwargs.pop("test_hdf5_keys_save_path", None)
    self.val_hdf5_keys_save_path = kwargs.pop("val_hdf5_keys_save_path", None)

    self.shuffle = kwargs.pop("shuffle", None)
    self.train_shuffle = kwargs.pop("train_shuffle", None) or self.shuffle
    self.val_shuffle = kwargs.pop("val_shuffle", None)
    self.test_shuffle = kwargs.pop("test_shuffle", None)

    self.train_data_fraction = kwargs.pop("train_data_fraction", 1.0)
    self.val_data_fraction = kwargs.pop("val_data_fraction", 1.0)
    self.test_data_fraction = kwargs.pop("test_data_fraction", 1.0)

    if self.train_data_fraction != 1.0  and  not train_hdf5_keys_path: raise ValueError(f"train_data_fraction provided as non-unity but train_hdf5_keys_path is unset.")
    if self.val_data_fraction != 1.0  and  not val_hdf5_keys_path: raise ValueError(f"val_data_fraction provided as non-unity but val_hdf5_keys_path is unset.")
    if self.test_data_fraction != 1.0  and  not test_hdf5_keys_path: raise ValueError(f"test_data_fraction provided as non-unity but test_hdf5_keys_path is unset.")

    all_hdf5_data_path = kwargs.pop("all_hdf5_data_path", None)
    if all_hdf5_data_path is not None:
        print(f"all_hdf5_data_path provided, will be interpreted as the general data path for all splits.\nKeys in provided train_hdf5_keys_path assumed to encompass all keys for entire data. Validation and Test keys will be subtracted from Train keys.")
        if self.train_hdf5_path: raise ValueError(f"Both general all_hdf5_data_path provided and a specific train_hdf5_path, remove the train_hdf5_path")
        if self.val_hdf5_path: raise ValueError(f"Both general all_hdf5_data_path provided and a specific val_hdf5_path, remove the val_hdf5_path")
        if self.test_hdf5_path: raise ValueError(f"Both general all_hdf5_data_path provided and a specific test_hdf5_path, remove the test_hdf5_path")
        self.train_hdf5_path = all_hdf5_data_path
        self.val_hdf5_path = all_hdf5_data_path
        self.test_hdf5_path = all_hdf5_data_path
        self.reduce_train_keys = True
    else:
        self.reduce_train_keys = False

    self.resize = kwargs.pop("resize", False)
    self.resize_to = kwargs.pop("resize_to", None)
    if self.resize and self.resize_to is None:
        raise ValueError(f"Config provided resize as True, but resize_to parameter not given")
    self.resize_interpolation = kwargs.pop("resize_interpolation", None)
    if self.resize and self.resize_interpolation is None:
        print(f"Config provided resize as True, but resize_interpolation mode not given. Will assume default bilinear")
        self.resize_interpolation = "bilinear"
    interpolation_dict = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
        "nearest": InterpolationMode.NEAREST,
        "nearest_exact": InterpolationMode.NEAREST_EXACT
    }
    if self.resize:
        if self.resize_interpolation not in interpolation_dict.keys():
            raise ValueError(f"resize_interpolation provided as {self.resize_interpolation}, but valid options are: {interpolation_dict.keys()}")
        self.resize_interpolation = interpolation_dict[self.resize_interpolation]
    self.resize_antialiasing = kwargs.pop("resize_antialiasing", True)

    self.kwargs = kwargs
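The options that Sen4MapLucasDataModule pops from **kwargs interact in a few non-obvious ways: train_shuffle falls back to shuffle, a non-unity data fraction needs a keys file, and resize needs resize_to while defaulting to bilinear interpolation. The following is a minimal, dependency-free sketch of that resolution logic for the training split. resolve_sen4map_options and its train_keys_path parameter are hypothetical and not part of TerraTorch; unlike the data module, the sketch keeps the interpolation mode as a string instead of converting it to a torchvision InterpolationMode.

```python
# Hypothetical helper mirroring how Sen4MapLucasDataModule.__init__ resolves
# options taken from **kwargs (training split only; not TerraTorch API).
VALID_INTERPOLATIONS = ("bilinear", "bicubic", "nearest", "nearest_exact")


def resolve_sen4map_options(kwargs: dict, train_keys_path=None) -> dict:
    opts = dict(kwargs)  # work on a copy so the caller's dict is untouched
    shuffle = opts.pop("shuffle", None)
    resolved = {
        # train_shuffle falls back to the global flag; val/test do not
        "train_shuffle": opts.pop("train_shuffle", None) or shuffle,
        "val_shuffle": opts.pop("val_shuffle", None),
        "test_shuffle": opts.pop("test_shuffle", None),
        "train_data_fraction": opts.pop("train_data_fraction", 1.0),
    }
    # A non-unity fraction is only accepted together with a keys file
    if resolved["train_data_fraction"] != 1.0 and not train_keys_path:
        raise ValueError("train_data_fraction requires train_hdf5_keys_path")

    resize = opts.pop("resize", False)
    resize_to = opts.pop("resize_to", None)
    interpolation = opts.pop("resize_interpolation", None)
    if resize and resize_to is None:
        raise ValueError("resize=True requires resize_to")
    if resize and interpolation is None:
        interpolation = "bilinear"  # same default the data module assumes
    if resize and interpolation not in VALID_INTERPOLATIONS:
        raise ValueError(f"invalid resize_interpolation: {interpolation}")
    resolved.update(resize=resize, resize_to=resize_to, resize_interpolation=interpolation)
    resolved["remaining_kwargs"] = opts  # what would be forwarded to the dataset
    return resolved


opts = resolve_sen4map_options({"shuffle": True, "resize": True, "resize_to": 224})
print(opts["train_shuffle"], opts["val_shuffle"], opts["resize_interpolation"])
# True None bilinear
```

Note that validation and test loaders are not shuffled unless val_shuffle or test_shuffle is set explicitly.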
setup(stage)

Set up datasets.

Parameters:
  • stage (str) –

    Either 'fit' or 'test'. Other stages are ignored.

Source code in terratorch/datamodules/sen4map.py
def setup(self, stage: str):
    """Set up datasets.

    Args:
        stage: Either fit, test.
    """
    if stage == "fit":
        train_keys = self._load_hdf5_keys_from_path(self.train_hdf5_keys_path, fraction=self.train_data_fraction)
        val_keys = self._load_hdf5_keys_from_path(self.val_hdf5_keys_path, fraction=self.val_data_fraction)
        if self.reduce_train_keys:
            test_keys = self._load_hdf5_keys_from_path(self.test_hdf5_keys_path, fraction=self.test_data_fraction)
            train_keys = list(set(train_keys) - set(val_keys) - set(test_keys))
        train_file = h5py.File(self.train_hdf5_path, 'r')
        self.lucasS2_train = Sen4MapDatasetMonthlyComposites(
            train_file, 
            h5data_keys = train_keys, 
            resize = self.resize,
            resize_to = self.resize_to,
            resize_interpolation = self.resize_interpolation,
            resize_antialiasing = self.resize_antialiasing,
            save_keys_path = self.train_hdf5_keys_save_path,
            **self.kwargs
        )
        val_file = h5py.File(self.val_hdf5_path, 'r')
        self.lucasS2_val = Sen4MapDatasetMonthlyComposites(
            val_file, 
            h5data_keys=val_keys, 
            resize = self.resize,
            resize_to = self.resize_to,
            resize_interpolation = self.resize_interpolation,
            resize_antialiasing = self.resize_antialiasing,
            save_keys_path = self.val_hdf5_keys_save_path,
            **self.kwargs
        )
    if stage == "test":
        test_file = h5py.File(self.test_hdf5_path, 'r')
        test_keys = self._load_hdf5_keys_from_path(self.test_hdf5_keys_path, fraction=self.test_data_fraction)
        self.lucasS2_test = Sen4MapDatasetMonthlyComposites(
            test_file, 
            h5data_keys=test_keys, 
            resize = self.resize,
            resize_to = self.resize_to,
            resize_interpolation = self.resize_interpolation,
            resize_antialiasing = self.resize_antialiasing,
            save_keys_path = self.test_hdf5_keys_save_path,
            **self.kwargs
        )
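When all_hdf5_data_path is used, setup("fit") derives the training keys by removing every validation and test key from the full key list. The sketch below reproduces only that set arithmetic on plain lists; reduce_train_keys is a hypothetical helper, and loading keys from disk (done by _load_hdf5_keys_from_path in the data module) is not modeled.

```python
# Hypothetical helper reproducing the key reduction in setup("fit"):
# train keys = all keys - validation keys - test keys.
def reduce_train_keys(train_keys, val_keys, test_keys):
    # Set difference, as in the data module; the original key order is not preserved.
    return list(set(train_keys) - set(val_keys) - set(test_keys))


all_keys = ["s0", "s1", "s2", "s3", "s4"]
train = reduce_train_keys(all_keys, val_keys=["s1"], test_keys=["s3", "s4"])
print(sorted(train))  # ['s0', 's2']
```

Because the result comes from a set, sort it (or shuffle deliberately) if a reproducible ordering matters.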

Transforms

The transforms module provides albumentations-compatible image transformations for reshaping spatial, temporal, and multimodal data: flattening temporal or sample dimensions into channels and restoring them, selecting bands, rearranging axes with einops patterns, and applying transforms across several modalities. They make it possible to format multi-temporal, multi-channel, and multimodal samples in the layout each model architecture expects.

terratorch.datasets.transforms

FlattenSamplesIntoChannels

Bases: ImageOnlyTransform

FlattenSamplesIntoChannels is an image transformation that merges the sample (and optionally temporal) dimensions into the channel dimension.

This transform expects a channels-last input of shape (samples, time, height, width, channels), or (samples, height, width, channels) when time_dim is False, and concatenates the sample (and temporal) dimensions into the channel dimension, returning (height, width, samples * time * channels).

Source code in terratorch/datasets/transforms.py
class FlattenSamplesIntoChannels(ImageOnlyTransform):
    """
    FlattenSamplesIntoChannels is an image transformation that merges the sample (and optionally temporal) dimensions into the channel dimension.

    This transform rearranges an input tensor by flattening the sample dimension, and if specified, also the temporal dimension,
    thereby concatenating these dimensions into a single channel dimension.
    """
    def __init__(self, time_dim: bool = True):
        """
        Initialize the FlattenSamplesIntoChannels transform.

        Args:
            time_dim (bool): If True, the temporal dimension is included in the flattening process. Default is True.
        """
        super().__init__(True, 1)
        self.time_dim = time_dim

    def apply(self, img, **params):
        if self.time_dim:
            rearranged = rearrange(img,
                                   "samples time height width channels -> height width (samples time channels)")
        else:
            rearranged = rearrange(img, "samples height width channels -> height width (samples channels)")
        return rearranged

    def get_transform_init_args_names(self):
        return ()
__init__(time_dim=True)

Initialize the FlattenSamplesIntoChannels transform.

Parameters:
  • time_dim (bool, default: True ) –

    If True, the temporal dimension is included in the flattening process. Default is True.

Source code in terratorch/datasets/transforms.py
def __init__(self, time_dim: bool = True):
    """
    Initialize the FlattenSamplesIntoChannels transform.

    Args:
        time_dim (bool): If True, the temporal dimension is included in the flattening process. Default is True.
    """
    super().__init__(True, 1)
    self.time_dim = time_dim
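To make the resulting channel order concrete, here is a NumPy sketch that performs the same reshaping as the two einops patterns in FlattenSamplesIntoChannels.apply. flatten_samples_into_channels is a hypothetical stand-in written with transpose and reshape rather than einops, so it runs without TerraTorch.

```python
import numpy as np


def flatten_samples_into_channels(img: np.ndarray, time_dim: bool = True) -> np.ndarray:
    """NumPy equivalent of the einops patterns used by FlattenSamplesIntoChannels."""
    if time_dim:
        # samples time height width channels -> height width (samples time channels)
        s, t, h, w, c = img.shape
        return img.transpose(2, 3, 0, 1, 4).reshape(h, w, s * t * c)
    # samples height width channels -> height width (samples channels)
    s, h, w, c = img.shape
    return img.transpose(1, 2, 0, 3).reshape(h, w, s * c)


x = np.random.rand(2, 3, 8, 8, 4)  # 2 samples, 3 timesteps, 8x8 pixels, 4 bands
flat = flatten_samples_into_channels(x)
print(flat.shape)  # (8, 8, 24)
```

Samples vary slowest and bands fastest: channels 0 to 3 hold sample 0 at timestep 0, channels 4 to 7 hold sample 0 at timestep 1, and so on.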

FlattenTemporalIntoChannels

Bases: ImageOnlyTransform

FlattenTemporalIntoChannels is an image transformation that flattens the temporal dimension into the channel dimension.

This transform rearranges a channels-last input of shape (time, height, width, channels) into (height, width, time * channels), merging the time and channel dimensions. The input must have exactly N_DIMS_FOR_TEMPORAL dimensions; otherwise an Exception is raised.

Source code in terratorch/datasets/transforms.py
class FlattenTemporalIntoChannels(ImageOnlyTransform):
    """
    FlattenTemporalIntoChannels is an image transformation that flattens the temporal dimension into the channel dimension.

    This transform rearranges an input tensor with a temporal dimension into one where the time and channel dimensions
    are merged. It expects the input to have a fixed number of dimensions defined by N_DIMS_FOR_TEMPORAL.
    """
    def __init__(self):
        """
        Initialize the FlattenTemporalIntoChannels transform.
        """
        super().__init__(True, 1)

    def apply(self, img, **params):
        if len(img.shape) != N_DIMS_FOR_TEMPORAL:
            msg = f"Expected input temporal sequence to have {N_DIMS_FOR_TEMPORAL} dimensions, but got {len(img.shape)}"
            raise Exception(msg)
        rearranged = rearrange(img, "time height width channels -> height width (time channels)")
        return rearranged

    def get_transform_init_args_names(self):
        return ()
__init__()

Initialize the FlattenTemporalIntoChannels transform.

Source code in terratorch/datasets/transforms.py
def __init__(self):
    """
    Initialize the FlattenTemporalIntoChannels transform.
    """
    super().__init__(True, 1)
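A minimal NumPy sketch of the same rearrangement, useful for checking which output channel belongs to which timestep. flatten_temporal_into_channels is hypothetical; it assumes N_DIMS_FOR_TEMPORAL is 4 (matching the four axes of the einops pattern) and raises ValueError where the transform raises a bare Exception.

```python
import numpy as np


def flatten_temporal_into_channels(img: np.ndarray) -> np.ndarray:
    # Mirrors the dimension check; 4 is assumed for N_DIMS_FOR_TEMPORAL.
    if img.ndim != 4:
        raise ValueError(f"Expected 4 dimensions (time, height, width, channels), got {img.ndim}")
    t, h, w, c = img.shape
    # time height width channels -> height width (time channels)
    return img.transpose(1, 2, 0, 3).reshape(h, w, t * c)


series = np.random.rand(3, 16, 16, 6)  # 3 timesteps, 16x16 pixels, 6 bands
flat = flatten_temporal_into_channels(series)
print(flat.shape)  # (16, 16, 18)
```

Channels 0 to 5 of the output come from the first timestep, 6 to 11 from the second, and 12 to 17 from the third.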

MultimodalTransforms

MultimodalTransforms applies albumentations transforms to multiple image modalities.

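This class supports both shared transformations across modalities and separate transformations for each modality. In shared mode, keys listed in non_image_modalities are processed with non_image_transform after the albumentations call; in separate mode, transforms must provide an image transform for every key in the sample.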

Source code in terratorch/datasets/transforms.py
class MultimodalTransforms:
    """
    MultimodalTransforms applies albumentations transforms to multiple image modalities.

    This class supports both shared transformations across modalities and separate transformations for each modality.
    It also handles non-image modalities by applying a specified non-image transform.
    """
    def __init__(
            self,
            transforms: dict | A.Compose,
            shared : bool = True,
            non_image_modalities: list[str] | None = None,
            non_image_transform: object | None = None,
    ):
        """
        Initialize the MultimodalTransforms.

        Args:
            transforms (dict or A.Compose): The transformation(s) to apply to the data.
            shared (bool): If True, the same transform is applied to all modalities; if False, separate transforms are used.
            non_image_modalities (list[str] | None): List of keys corresponding to non-image modalities.
            non_image_transform (object | None): A transform to apply to non-image modalities. If None, a default transform is used.
        """
        self.transforms = transforms
        self.shared = shared
        self.non_image_modalities = non_image_modalities or []  # avoid set(None) in shared mode
        self.non_image_transform = non_image_transform or default_non_image_transform

    def __call__(self, data: dict):
        if self.shared:
            # albumentations requires a key 'image' and treats all other keys as additional targets
            image_modality = list(set(data.keys()) - set(self.non_image_modalities))[0]
            data['image'] = data.pop(image_modality)
            data = self.transforms(**data)
            data[image_modality] = data.pop('image')

            # Process sequence data which is ignored by albumentations as 'global_label'
            for modality in self.non_image_modalities:
                data[modality] = self.non_image_transform(data[modality])
        else:
            # Apply a separate transform to each modality
            for key, value in data.items():
                data[key] = self.transforms[key](image=value)['image']  # Only works with image modalities

        return data
__init__(transforms, shared=True, non_image_modalities=None, non_image_transform=None)

Initialize the MultimodalTransforms.

Parameters:
  • transforms (dict or Compose) –

    The transformation(s) to apply to the data.

  • shared (bool, default: True ) –

    If True, the same transform is applied to all modalities; if False, transforms must be a dict mapping each modality key to its own transform.

  • non_image_modalities (list[str] | None, default: None ) –

    List of keys corresponding to non-image modalities.

  • non_image_transform (object | None, default: None ) –

    A transform to apply to non-image modalities. If None, a default transform is used.

Source code in terratorch/datasets/transforms.py
def __init__(
        self,
        transforms: dict | A.Compose,
        shared : bool = True,
        non_image_modalities: list[str] | None = None,
        non_image_transform: object | None = None,
):
    """
    Initialize the MultimodalTransforms.

    Args:
        transforms (dict or A.Compose): The transformation(s) to apply to the data.
        shared (bool): If True, the same transform is applied to all modalities; if False, separate transforms are used.
        non_image_modalities (list[str] | None): List of keys corresponding to non-image modalities.
        non_image_transform (object | None): A transform to apply to non-image modalities. If None, a default transform is used.
    """
    self.transforms = transforms
    self.shared = shared
    self.non_image_modalities = non_image_modalities or []  # avoid set(None) in shared mode
    self.non_image_transform = non_image_transform or default_non_image_transform
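Albumentations requires one input to be passed under the key 'image', so in shared mode MultimodalTransforms temporarily renames one image modality to 'image', runs the pipeline, and renames it back. The sketch below illustrates that mechanism without albumentations: apply_shared and reverse_image are hypothetical, reverse_image stands in for an A.Compose pipeline, the modality names are made up, and the identity fallback is an assumption (the real class uses default_non_image_transform, whose behavior is not shown here). The real class picks the image modality from a set, so with several image modalities the choice is arbitrary; the sketch takes the first in insertion order.

```python
def apply_shared(data, transform, non_image_modalities=None, non_image_transform=None):
    non_image_modalities = non_image_modalities or []
    # Assumption: identity fallback instead of TerraTorch's default_non_image_transform
    non_image_transform = non_image_transform or (lambda value: value)
    data = dict(data)  # shallow copy so the caller's sample is not mutated
    # Pick one image modality to play albumentations' required "image" role
    image_modality = [k for k in data if k not in non_image_modalities][0]
    data["image"] = data.pop(image_modality)
    data = transform(**data)  # other keys would be additional targets in albumentations
    data[image_modality] = data.pop("image")
    # Non-image modalities are handled by their own transform afterwards
    for modality in non_image_modalities:
        data[modality] = non_image_transform(data[modality])
    return data


def reverse_image(**targets):
    """Stand-in for an A.Compose pipeline: reverses only the 'image' entry."""
    out = dict(targets)
    out["image"] = out["image"][::-1]
    return out


sample = {"optical": [1, 2, 3], "coords": [45.0, 7.5]}
out = apply_shared(sample, reverse_image, ["coords"], tuple)
print(out["optical"], out["coords"])  # [3, 2, 1] (45.0, 7.5)
```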

Rearrange

Bases: ImageOnlyTransform

Rearrange is a generic image transformation that reshapes an input tensor using a custom einops pattern.

This transform allows flexible reordering of tensor dimensions based on the provided pattern and arguments.

Source code in terratorch/datasets/transforms.py
class Rearrange(ImageOnlyTransform):
    """
    Rearrange is a generic image transformation that reshapes an input tensor using a custom einops pattern.

    This transform allows flexible reordering of tensor dimensions based on the provided pattern and arguments.
    """

    def __init__(
        self, rearrange: str, rearrange_args: dict[str, int] | None = None, always_apply: bool = True, p: float = 1
    ):
        """
        Initialize the Rearrange transform.

        Args:
            rearrange (str): The einops rearrangement pattern to apply.
            rearrange_args (dict[str, int] | None): Additional arguments for the rearrangement pattern.
            always_apply (bool): Whether to always apply this transform. Default is True.
            p (float): The probability of applying the transform. Default is 1.
        """
        super().__init__(always_apply, p)
        self.rearrange = rearrange
        self.vars = rearrange_args if rearrange_args else {}

    def apply(self, img, **params):
        return rearrange(img, self.rearrange, **self.vars)

    def get_transform_init_args_names(self):
        return ("rearrange",)
__init__(rearrange, rearrange_args=None, always_apply=True, p=1)

Initialize the Rearrange transform.

Parameters:
  • rearrange (str) –

    The einops rearrangement pattern to apply.

  • rearrange_args (dict[str, int] | None, default: None ) –

    Additional arguments for the rearrangement pattern.

  • always_apply (bool, default: True ) –

    Whether to always apply this transform. Default is True.

  • p (float, default: 1 ) –

    The probability of applying the transform. Default is 1.

Source code in terratorch/datasets/transforms.py
def __init__(
    self, rearrange: str, rearrange_args: dict[str, int] | None = None, always_apply: bool = True, p: float = 1
):
    """
    Initialize the Rearrange transform.

    Args:
        rearrange (str): The einops rearrangement pattern to apply.
        rearrange_args (dict[str, int] | None): Additional arguments for the rearrangement pattern.
        always_apply (bool): Whether to always apply this transform. Default is True.
        p (float): The probability of applying the transform. Default is 1.
    """
    super().__init__(always_apply, p)
    self.rearrange = rearrange
    self.vars = rearrange_args if rearrange_args else {}
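rearrange_args supplies axis sizes that einops cannot infer on its own, such as how a grouped axis should be split. For example, Rearrange("(t c) h w -> t c h w", rearrange_args={"t": 2}) would split a stacked channel axis into two timesteps. The sketch below shows the equivalent NumPy operation for that particular pattern; rearrange_time_channels is a hypothetical helper, not part of TerraTorch or einops.

```python
import numpy as np


def rearrange_time_channels(img: np.ndarray, t: int) -> np.ndarray:
    # NumPy equivalent of einops.rearrange(img, "(t c) h w -> t c h w", t=t):
    # the grouped axis is split with t as the outer (slower-varying) factor.
    tc, h, w = img.shape
    if tc % t:
        raise ValueError(f"{tc} channels cannot be split into {t} timesteps")
    return img.reshape(t, tc // t, h, w)


x = np.arange(6 * 2 * 2).reshape(6, 2, 2)  # 6 stacked channels = 2 timesteps x 3 bands
y = rearrange_time_channels(x, t=2)
print(y.shape)  # (2, 3, 2, 2)
```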

SelectBands

Bases: ImageOnlyTransform

SelectBands is an image transformation that selects a subset of bands (channels) from an input image.

This transform uses specified band indices to filter and output only the desired channels from the image tensor.

Source code in terratorch/datasets/transforms.py
class SelectBands(ImageOnlyTransform):
    """
    SelectBands is an image transformation that selects a subset of bands (channels) from an input image.

    This transform uses specified band indices to filter and output only the desired channels from the image tensor.
    """

    def __init__(self, band_indices: list[int]):
        """
        Initialize the SelectBands transform.

        Args:
            band_indices (list[int]): A list of indices specifying which bands to select.
        """
        super().__init__(True, 1)
        self.band_indices = band_indices

    def apply(self, img, **params):
        return img[..., self.band_indices]

    def get_transform_init_args_names(self):
        return ("band_indices",)
__init__(band_indices)

Initialize the SelectBands transform.

Parameters:
  • band_indices (list[int]) –

    A list of indices specifying which bands to select.

Source code in terratorch/datasets/transforms.py
def __init__(self, band_indices: list[int]):
    """
    Initialize the SelectBands transform.

    Args:
        band_indices (list[int]): A list of indices specifying which bands to select.
    """
    super().__init__(True, 1)
    self.band_indices = band_indices
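SelectBands indexes the last axis (img[..., band_indices]), so it assumes channels-last images and should run before any HWC-to-CHW conversion such as ToTensorV2; on a channels-first array it would select along the width instead. A short NumPy illustration of the same indexing, where select_bands is a hypothetical stand-in:

```python
import numpy as np


def select_bands(img: np.ndarray, band_indices: list[int]) -> np.ndarray:
    # Same indexing as SelectBands.apply: picks entries along the LAST axis
    return img[..., band_indices]


img = np.random.rand(64, 64, 6)  # channels-last (height, width, bands)
reordered = select_bands(img, [2, 1, 0])
print(reordered.shape)  # (64, 64, 3)
```

The index list also fixes the output order, so it can reorder bands as well as drop them.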

UnflattenSamplesFromChannels

Bases: ImageOnlyTransform

UnflattenSamplesFromChannels is an image transformation that restores the sample (and optionally temporal) dimensions from the channel dimension.

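This transform is designed to reverse the flattening performed by FlattenSamplesIntoChannels and is typically applied after converting images to a channels-first format. It expects an input of shape (samples * time * channels, height, width) and returns (samples, channels, time, height, width), or (samples, channels, height, width) when time_dim is False.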

Source code in terratorch/datasets/transforms.py
class UnflattenSamplesFromChannels(ImageOnlyTransform):
    """
    UnflattenSamplesFromChannels is an image transformation that restores the sample (and optionally temporal) dimensions from the channel dimension.

    This transform is designed to reverse the flattening performed by FlattenSamplesIntoChannels and is typically applied
    after converting images to a channels-first format.
    """
    def __init__(
            self,
            time_dim: bool = True,
            n_samples: int | None = None,
            n_timesteps: int | None = None,
            n_channels: int | None = None
    ):
        """
        Initialize the UnflattenSamplesFromChannels transform.

        Args:
            time_dim (bool): If True, the temporal dimension is considered during unflattening.
            n_samples (int | None): The number of samples.
            n_timesteps (int | None): The number of time steps.
            n_channels (int | None): The number of channels per time step.

        Raises:
            Exception: If time_dim is True and fewer than two of n_channels, n_timesteps, and n_samples are provided.
            Exception: If time_dim is False and neither n_channels nor n_samples is provided.
        """
        super().__init__(True, 1)

        self.time_dim = time_dim
        if self.time_dim:
            if bool(n_channels) + bool(n_timesteps) + bool(n_samples) < 2:
                msg = "Two of n_channels, n_timesteps, and n_samples must be provided"
                raise Exception(msg)
            if n_timesteps and n_channels:
                self.additional_info = {"channels": n_channels, "time": n_timesteps}
            elif n_timesteps and n_samples:
                self.additional_info = {"time": n_timesteps, "samples": n_samples}
            else:
                self.additional_info = {"channels": n_channels, "samples": n_samples}
        else:
            if n_channels is None and n_samples is None:
                msg = "One of n_channels or n_samples must be provided"
                raise Exception(msg)
            self.additional_info = {"channels": n_channels} if n_channels else {"samples": n_samples}

    def apply(self, img, **params):
        if self.time_dim:
            rearranged = rearrange(
                img, "(samples time channels) height width -> samples channels time height width",
                **self.additional_info
            )
        else:
            rearranged = rearrange(
                img, "(samples channels) height width -> samples channels height width", **self.additional_info
            )
        return rearranged

    def get_transform_init_args_names(self):
        return ("n_timesteps", "n_channels")
__init__(time_dim=True, n_samples=None, n_timesteps=None, n_channels=None)

Initialize the UnflattenSamplesFromChannels transform.

Parameters:
  • time_dim (bool, default: True ) –

    If True, the temporal dimension is considered during unflattening.

  • n_samples (int | None, default: None ) –

    The number of samples.

  • n_timesteps (int | None, default: None ) –

    The number of time steps.

  • n_channels (int | None, default: None ) –

    The number of channels per time step.

Raises:
  • Exception

    If time_dim is True and fewer than two of n_channels, n_timesteps, and n_samples are provided.

  • Exception

    If time_dim is False and neither n_channels nor n_samples is provided.

Source code in terratorch/datasets/transforms.py
def __init__(
        self,
        time_dim: bool = True,
        n_samples: int | None = None,
        n_timesteps: int | None = None,
        n_channels: int | None = None
):
    """
    Initialize the UnflattenSamplesFromChannels transform.

    Args:
        time_dim (bool): If True, the temporal dimension is considered during unflattening.
        n_samples (int | None): The number of samples.
        n_timesteps (int | None): The number of time steps.
        n_channels (int | None): The number of channels per time step.

    Raises:
        Exception: If time_dim is True and fewer than two of n_channels, n_timesteps, and n_samples are provided.
        Exception: If time_dim is False and neither n_channels nor n_samples is provided.
    """
    super().__init__(True, 1)

    self.time_dim = time_dim
    if self.time_dim:
        if bool(n_channels) + bool(n_timesteps) + bool(n_samples) < 2:
            msg = "Two of n_channels, n_timesteps, and n_samples must be provided"
            raise Exception(msg)
        if n_timesteps and n_channels:
            self.additional_info = {"channels": n_channels, "time": n_timesteps}
        elif n_timesteps and n_samples:
            self.additional_info = {"time": n_timesteps, "samples": n_samples}
        else:
            self.additional_info = {"channels": n_channels, "samples": n_samples}
    else:
        if n_channels is None and n_samples is None:
            msg = "One of n_channels or n_samples must be provided"
            raise Exception(msg)
        self.additional_info = {"channels": n_channels} if n_channels else {"samples": n_samples}
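The round trip below checks, in NumPy, that flattening samples and time into channels, converting HWC to CHW, and then applying the unflatten pattern recovers every value, with the output ordered (samples, channels, time, height, width). unflatten_samples_from_channels is a hypothetical helper that takes n_samples and n_timesteps; the real transform accepts any two of n_samples, n_timesteps, and n_channels.

```python
import numpy as np


def unflatten_samples_from_channels(img: np.ndarray, n_samples: int, n_timesteps: int) -> np.ndarray:
    # "(samples time channels) height width -> samples channels time height width"
    stc, h, w = img.shape
    c = stc // (n_samples * n_timesteps)
    return img.reshape(n_samples, n_timesteps, c, h, w).transpose(0, 2, 1, 3, 4)


x = np.random.rand(2, 3, 8, 8, 4)  # (samples, time, height, width, channels)
flat = x.transpose(2, 3, 0, 1, 4).reshape(8, 8, 24)  # what FlattenSamplesIntoChannels produces
chw = flat.transpose(2, 0, 1)  # HWC -> CHW, as ToTensorV2 would do
restored = unflatten_samples_from_channels(chw, n_samples=2, n_timesteps=3)
print(restored.shape)  # (2, 4, 3, 8, 8)
```

The result is not the original layout: channels now precede time, matching the einops pattern in apply.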

UnflattenTemporalFromChannels

Bases: ImageOnlyTransform

UnflattenTemporalFromChannels is an image transformation that restores the temporal dimension from the channel dimension.

This transform is typically applied after converting images to a channels-first format (e.g., after ToTensorV2). It rearranges an input of shape (time * channels, height, width) back into separate time and channel dimensions, returning (channels, time, height, width).

Source code in terratorch/datasets/transforms.py
class UnflattenTemporalFromChannels(ImageOnlyTransform):
    """
    UnflattenTemporalFromChannels is an image transformation that restores the temporal dimension from the channel dimension.

    This transform is typically applied after converting images to a channels-first format (e.g., after ToTensorV2)
    and rearranges the flattened temporal information back into separate time and channel dimensions.
    """

    def __init__(self, n_timesteps: int | None = None, n_channels: int | None = None):
        """
        Initialize the UnflattenTemporalFromChannels transform.

        Args:
            n_timesteps (int | None): The number of time steps. Must be provided if n_channels is not provided.
            n_channels (int | None): The number of channels per time step. Must be provided if n_timesteps is not provided.

        Raises:
            Exception: If neither n_timesteps nor n_channels is provided.
        """
        super().__init__(True, 1)
        if n_timesteps is None and n_channels is None:
            msg = "One of n_timesteps or n_channels must be provided"
            raise Exception(msg)
        self.additional_info = {"channels": n_channels} if n_channels else {"time": n_timesteps}

    def apply(self, img, **params):
        if len(img.shape) != N_DIMS_FLATTENED_TEMPORAL:
            msg = (
                f"Expected input temporal sequence to have {N_DIMS_FLATTENED_TEMPORAL} dimensions, "
                f"but got {len(img.shape)}"
            )
            raise Exception(msg)

        rearranged = rearrange(
            img, "(time channels) height width -> channels time height width", **self.additional_info
        )
        return rearranged

    def get_transform_init_args_names(self):
        return ("n_timesteps", "n_channels")
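A NumPy round trip for the temporal case: flatten (time, height, width, channels) into channels, convert HWC to CHW as ToTensorV2 would, then split the channel axis again. unflatten_temporal_from_channels is a hypothetical helper taking n_timesteps; the real transform alternatively accepts n_channels.

```python
import numpy as np


def unflatten_temporal_from_channels(img: np.ndarray, n_timesteps: int) -> np.ndarray:
    # "(time channels) height width -> channels time height width"
    tc, h, w = img.shape
    return img.reshape(n_timesteps, tc // n_timesteps, h, w).transpose(1, 0, 2, 3)


x = np.random.rand(3, 8, 8, 4)  # (time, height, width, channels)
flat = x.transpose(1, 2, 0, 3).reshape(8, 8, 12)  # what FlattenTemporalIntoChannels produces
chw = flat.transpose(2, 0, 1)  # HWC -> CHW
restored = unflatten_temporal_from_channels(chw, n_timesteps=3)
print(restored.shape)  # (4, 3, 8, 8)
```

The output is channels-first with time as the second axis, which is the (channels, time, height, width) layout described above.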